Spaces:
Sleeping
Sleeping
Commit
·
592f71e
0
Parent(s):
init
Browse files- README.md +12 -0
- app.py +277 -0
- models/__init__.py +0 -0
- models/__pycache__/__init__.cpython-37.pyc +0 -0
- models/__pycache__/__init__.cpython-39.pyc +0 -0
- models/__pycache__/data_utils.cpython-37.pyc +0 -0
- models/__pycache__/dense_retriever.cpython-37.pyc +0 -0
- models/__pycache__/nli.cpython-37.pyc +0 -0
- models/__pycache__/nli.cpython-39.pyc +0 -0
- models/__pycache__/qa_ranker.cpython-37.pyc +0 -0
- models/__pycache__/text_encoder.cpython-37.pyc +0 -0
- models/__pycache__/tokenization.cpython-37.pyc +0 -0
- models/__pycache__/vector_index.cpython-37.pyc +0 -0
- models/data_utils.py +1006 -0
- models/dense_retriever.py +47 -0
- models/encoder/.DS_Store +0 -0
- models/encoder/0_BERT/config.json +20 -0
- models/encoder/0_BERT/sentence_bert_config.json +4 -0
- models/encoder/0_BERT/special_tokens_map.json +1 -0
- models/encoder/0_BERT/tokenizer_config.json +1 -0
- models/encoder/0_BERT/vocab.txt +0 -0
- models/encoder/1_Pooling/config.json +7 -0
- models/encoder/config.json +3 -0
- models/encoder/modules.json +14 -0
- models/encoder/stats.csv +220 -0
- models/nli.py +267 -0
- models/qa_ranker.py +284 -0
- models/sparse_retriever.py +204 -0
- models/sparse_retriever_fast.py +48 -0
- models/text_encoder.py +867 -0
- models/tokenization.py +56 -0
- models/vector_index.py +102 -0
- requirements.txt +30 -0
- utils.py +71 -0
- web_search.py +140 -0
README.md
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
This repo is modified from [Quin+](https://github.com/algoprog/Quin).
|
2 |
+
|
3 |
+
To set up:
|
4 |
+
1. Copy `../models` folder into this folder
|
5 |
+
2. Follow instructions in the original repository to download the models.
|
6 |
+
3. Install the requirements. The requirements.txt file in the original repo was written for Python 3.7, whereas the requirements in this repo are based on Python 3.8. Choose according to the Python version available on your system.
|
7 |
+
4. Run
|
8 |
+
```shell
|
9 |
+
python quin_web.py
|
10 |
+
```
|
11 |
+
|
12 |
+
To check whether the backend is running correctly, set up the frontend and enter a query there. By default, at the `INFO` level, the logger records each incoming query.
|
app.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import gradio as gr
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
import nltk
|
6 |
+
import numpy
|
7 |
+
import os
|
8 |
+
import re
|
9 |
+
import requests
|
10 |
+
import torch
|
11 |
+
import urllib
|
12 |
+
import urllib.parse as urlparse
|
13 |
+
|
14 |
+
from bs4 import BeautifulSoup
|
15 |
+
from flask import request, Flask, jsonify
|
16 |
+
from flask_cors import CORS
|
17 |
+
from huggingface_hub import hf_hub_download, snapshot_download, login
|
18 |
+
from models.nli import NLI
|
19 |
+
from models.qa_ranker import PassageRanker
|
20 |
+
from models.text_encoder import SentenceTransformer
|
21 |
+
from multiprocessing.pool import ThreadPool
|
22 |
+
from scipy import spatial
|
23 |
+
from scipy.special import softmax
|
24 |
+
from time import sleep
|
25 |
+
from urllib.parse import parse_qs
|
26 |
+
from web_search import get_html, WebParser, duckduckgo_search
|
27 |
+
|
28 |
+
|
29 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
30 |
+
API_URLS = ['https://letscheck.nus.edu.sg/quin/']
|
31 |
+
logging.getLogger().setLevel(logging.INFO)
|
32 |
+
|
33 |
+
|
34 |
+
def is_question(query):
    """Heuristically decide whether ``query`` is a question rather than a statement.

    A query counts as a question when it starts with an interrogative or
    auxiliary word (matched as a whole word, case-insensitively), when it ends
    with ``?``, or when POS tagging finds no verb in it at all (a bare keyword
    query).

    :param query: the raw query string
    :return: ``True`` for a question, ``False`` for a statement
    """
    # The trailing \b keeps prefixes such as "is"/"do"/"how" from matching
    # inside longer words ("israel", "download", "however").
    starts_like_question = re.match(
        r'^(who|when|what|why|which|whose|is|are|was|were|do|does|did|how)\b',
        query,
        flags=re.IGNORECASE)
    if starts_like_question or query.endswith('?'):
        return True
    pos_tags = nltk.pos_tag(nltk.word_tokenize(query))
    # Any verb tag (VB, VBD, VBZ, ...) marks the query as a statement.
    return not any(tag.startswith('VB') for _, tag in pos_tags)
|
42 |
+
|
43 |
+
|
44 |
+
def clear_cache():
    """Ask every backend listed in ``API_URLS`` to clear its cache.

    Sends one POST request per base URL and prints the HTTP status code and
    the response object for each.
    """
    for url in API_URLS:
        print(f"Clearing cache from {url}")
        endpoint = f'{url}/clear_cache?secret=123'
        response = requests.post(endpoint)
        print(response.status_code)
        print(response)
|
50 |
+
|
51 |
+
|
52 |
+
class Quin:
    """Web-evidence retrieval pipeline for questions and factual statements.

    For a query it searches the web, downloads the result pages, cuts them
    into snippets, and then either re-ranks the snippets (questions) or
    classifies them as supporting/refuting evidence (statements).
    """

    def __init__(self):
        """Load the sentence tokenizer, create the Flask app and load all models."""
        nltk.download('punkt')
        self.sent_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
        # Register extra abbreviations with the tokenizer's abbreviation set.
        self.sent_tokenizer._params.abbrev_types.update(['e.g', 'i.e', 'subsp'])
        self.app = Flask(__name__)
        CORS(self.app)

        # Fetch model weights into the local cache/ directory.
        snapshot_download(repo_id="anabmaulana/qrbert-multitask", local_dir="cache/qrbert-multitask")
        hf_hub_download(repo_id="anabmaulana/quin-passage-ranker", filename="passage_ranker.state_dict",
                        local_dir="cache/quin-passage-ranker")
        hf_hub_download(repo_id="anabmaulana/quin-nli", filename="nli.state_dict", local_dir="cache/quin-nli")

        # torch.cuda.is_available = lambda: False
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.text_embedding_model = SentenceTransformer('cache/qrbert-multitask',
                                                        device=device,
                                                        parallel=True)
        self.passage_ranking_model = PassageRanker(
            model_path='cache/quin-passage-ranker/passage_ranker.state_dict',
            gpu=torch.cuda.is_available(),
            batch_size=16)
        self.passage_ranking_model.eval()
        self.nli_model = NLI('cache/quin-nli/nli.state_dict',
                             batch_size=64,
                             device=device,
                             parallel=True)
        self.nli_model.eval()
        logging.info('Initialized!')

    def extract_snippets(self, text, sentences_per_snippet=4):
        """Split ``text`` into snippets of consecutive sentences.

        :param text: the text block to split
        :param sentences_per_snippet: number of sentences joined into one snippet
        :return: list of snippets; snippets of four or fewer space-separated
            words are dropped
        """
        sentences = self.sent_tokenizer.tokenize(text)
        # Slicing past the end of the list is safe, so the last (possibly
        # shorter) group needs no special handling.
        chunks = (' '.join(sentences[start:start + sentences_per_snippet])
                  for start in range(0, len(sentences), sentences_per_snippet))
        return [chunk for chunk in chunks if len(chunk.split(' ')) > 4]

    def search_web_evidence(self, query, limit=10):
        """Search the web for ``query`` and return ranked or classified evidence.

        :param query: question or statement to look up
        :param limit: maximum number of search URLs and of returned results;
            ``None`` means no limit. Non-integer numbers are truncated to int.
        :return: for questions ``{'type': 'question', 'results': [...]}``; for
            statements a dict with ``type``, ``supporting``, ``refuting``,
            ``veracity_rating`` and ``results``
        """
        # List slicing below needs an int; UI widgets may hand over a float.
        if limit is not None:
            limit = int(limit)

        logging.info('searching the web...')
        urls = duckduckgo_search(query, pages=1)[:limit]
        logging.info('downloading {} web pages...'.format(len(urls)))
        search_results = []
        query_is_question = is_question(query)

        # Local helpers for multithreaded downloading.
        def download(url):
            data = get_html(url)
            soup = BeautifulSoup(data, features='lxml')
            # Pages without a <title> element must not abort the download.
            title = soup.title.string if soup.title is not None else None
            w = WebParser()
            w.feed(data)
            new_snippets = sum([self.extract_snippets(b) for b in w.blocks if b.count(' ') > 20], [])
            search_results.extend({'snippet': snippet, 'url': url, 'title': title}
                                  for snippet in new_snippets)

        def timeout_download(arg):
            pool = ThreadPool(1)
            try:
                pool.apply_async(download, [arg]).get(timeout=2)
            except Exception as e:
                # Best effort: a slow or broken page is skipped, not fatal.
                logging.info('skipping %s: %r', arg, e)
            finally:
                pool.close()
                pool.join()

        p = ThreadPool(32)
        p.map(timeout_download, urls)
        p.close()
        p.join()

        if query_is_question:
            logging.info('re-ranking...')
            qa_pairs = [(query, s['snippet']) for s in search_results]
            _, probs = self.passage_ranking_model(qa_pairs)
            probs = [softmax(p)[1] for p in probs]
            filtered_results = []
            for result, prob in zip(search_results, probs):
                if prob > 0.35:
                    result['score'] = str(prob)
                    filtered_results.append(result)
            search_results = sorted(filtered_results, key=lambda x: float(x['score']), reverse=True)

        # Highlight the sentences most similar to the query.
        logging.info('highlighting...')
        results_sentences = []
        sentences_texts = []
        for r in search_results:
            sentences = self.sent_tokenizer.tokenize(r['snippet'])
            sentences = [s for s in sentences if len(s.split(' ')) > 4]
            sentences_texts.extend(sentences)
            results_sentences.append(sentences)

        vectors = self.text_embedding_model.encode(sentences=sentences_texts, batch_size=128)
        sentences_vectors = dict(zip(sentences_texts, vectors))

        query_vector = self.text_embedding_model.encode(sentences=[query], batch_size=1)[0]
        for result, sentences in zip(search_results, results_sentences):
            best_sentences = set()
            evidence_sentences = []
            for sentence in sentences:
                score = 1 - spatial.distance.cosine(query_vector, sentences_vectors[sentence])
                if score > 0.91:
                    best_sentences.add(sentence)
                    evidence_sentences.append(sentence)
            if evidence_sentences:
                result['evidence'] = ' '.join(evidence_sentences)
                result['snippet'] = \
                    ' '.join([s if s not in best_sentences else '<b>{}</b>'.format(s) for s in sentences])

        # Only results with at least one highlighted sentence are kept.
        search_results = [s for s in search_results if 'evidence' in s]

        if not query_is_question:
            logging.info('entailment classification...')
            es_pairs = [(result['evidence'], query) for result in search_results]
            try:
                labels, probs = self.nli_model(es_pairs)
                logging.info(str(labels))
                for i in range(len(labels)):
                    confidence = numpy.exp(numpy.max(probs[i]))
                    if confidence > 0.4:
                        search_results[i]['nli_class'] = labels[i]
                    else:
                        search_results[i]['nli_class'] = 'neutral'
                    search_results[i]['nli_confidence'] = str(confidence)
            except Exception as e:
                logging.warning('Error doing entailment classification')
                logging.warning(str(e))
            # Drop neutral results only when every result was classified;
            # after a failed classification the results are left untouched.
            if all('nli_class' in s for s in search_results):
                search_results = [s for s in search_results if s['nli_class'] != 'neutral']

        # Keep the first result per URL and count the verdicts of the kept ones.
        filtered_results = []
        added_urls = set()
        supporting = 0
        refuting = 0
        for r in search_results:
            if r['url'] not in added_urls:
                filtered_results.append(r)
                added_urls.add(r['url'])
                if 'nli_class' in r:
                    if r['nli_class'] == 'entailment':
                        supporting += 1
                    elif r['nli_class'] == 'contradiction':
                        refuting += 1
        search_results = filtered_results

        if supporting > 4 or refuting > 4:
            # The 0.001 term avoids division by zero.
            if supporting / (refuting + 0.001) > 1.7:
                veracity = 'Probably True'
            elif refuting / (supporting + 0.001) > 1.7:
                veracity = 'Probably False'
            else:
                veracity = '? Ambiguous'
        else:
            veracity = 'Not enough evidence'

        search_results = search_results[:limit]

        logging.info('done searching')

        if query_is_question:
            return {'type': 'question',
                    'results': search_results}
        else:
            return {'type': 'statement',
                    'supporting': supporting,
                    'refuting': refuting,
                    'veracity_rating': veracity,
                    'results': search_results}

    def build_endpoints(self):
        """Register the ``/search`` HTTP endpoint on the Flask app."""
        @self.app.route('/search', methods=['POST', 'GET'])
        def search_endpoint():
            query = request.args.get('query')
            if not query:
                return jsonify({'error': "missing 'query' parameter"}), 400
            try:
                # At most 100 results are served per request.
                limit = min(int(request.args.get('limit') or 100), 100)
            except ValueError:
                return jsonify({'error': "'limit' must be an integer"}), 400
            results = json.dumps(self.search_web_evidence(query.lower(), limit=limit), indent=4)
            return results

    def serve(self, port=12345):
        """Register the endpoints and run the Flask server on all interfaces.

        :param port: TCP port to listen on
        """
        self.build_endpoints()
        self.app.run(host='0.0.0.0', port=port)
|
258 |
+
|
259 |
+
|
260 |
+
|
261 |
+
fchecker = Quin()
|
262 |
+
|
263 |
+
def predict(query, limit):
    """Gradio callback: run the evidence search for ``query``.

    :param query: text entered by the user
    :param limit: value of the numeric input; may be a float, or ``None``
        when the field is left empty
    :return: the result dict produced by ``search_web_evidence``
    """
    # The numeric widget can deliver a float, but the limit is used as a
    # slice bound downstream, which requires an int.
    if limit is not None:
        limit = int(limit)
    return fchecker.search_web_evidence(query, limit=limit)
|
265 |
+
|
266 |
+
demo = gr.Interface(
|
267 |
+
fn=predict,
|
268 |
+
inputs=[
|
269 |
+
gr.Textbox(label="Enter Query"), # string input
|
270 |
+
gr.Number(label="Enter Limit") # number input
|
271 |
+
],
|
272 |
+
outputs=gr.JSON(label="Output")
|
273 |
+
)
|
274 |
+
|
275 |
+
|
276 |
+
if __name__ == '__main__':
|
277 |
+
demo.launch()
|
models/__init__.py
ADDED
File without changes
|
models/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (157 Bytes). View file
|
|
models/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (128 Bytes). View file
|
|
models/__pycache__/data_utils.cpython-37.pyc
ADDED
Binary file (23.7 kB). View file
|
|
models/__pycache__/dense_retriever.cpython-37.pyc
ADDED
Binary file (2.42 kB). View file
|
|
models/__pycache__/nli.cpython-37.pyc
ADDED
Binary file (8.4 kB). View file
|
|
models/__pycache__/nli.cpython-39.pyc
ADDED
Binary file (8.38 kB). View file
|
|
models/__pycache__/qa_ranker.cpython-37.pyc
ADDED
Binary file (9.03 kB). View file
|
|
models/__pycache__/text_encoder.cpython-37.pyc
ADDED
Binary file (29.5 kB). View file
|
|
models/__pycache__/tokenization.cpython-37.pyc
ADDED
Binary file (1.23 kB). View file
|
|
models/__pycache__/vector_index.cpython-37.pyc
ADDED
Binary file (3.42 kB). View file
|
|
models/data_utils.py
ADDED
@@ -0,0 +1,1006 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import csv
|
2 |
+
import json
|
3 |
+
import pickle
|
4 |
+
|
5 |
+
import logging
|
6 |
+
import re
|
7 |
+
from pprint import pprint
|
8 |
+
|
9 |
+
import pandas
|
10 |
+
import gzip
|
11 |
+
import os
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
|
15 |
+
from random import randint, random
|
16 |
+
from tqdm import tqdm
|
17 |
+
from typing import List, Dict
|
18 |
+
|
19 |
+
from spacy.lang.en import English
|
20 |
+
# from tensorflow.keras.utils import Sequence
|
21 |
+
from transformers import RobertaTokenizer, BertTokenizer, BertModel, AutoTokenizer, AutoModel
|
22 |
+
|
23 |
+
# from utils import WebParser
|
24 |
+
from models.dense_retriever import DenseRetriever
|
25 |
+
from models.tokenization import tokenize
|
26 |
+
|
27 |
+
from typing import Union, List
|
28 |
+
|
29 |
+
from unidecode import unidecode
|
30 |
+
|
31 |
+
import spacy
|
32 |
+
|
33 |
+
|
34 |
+
# from utils import WebParser
|
35 |
+
|
36 |
+
|
37 |
+
class InputExample:
    """
    A single input example: a unique id, one or more texts and a label
    """

    def __init__(self, guid: str, texts: List[str], label: Union[int, float]):
        """
        Builds an InputExample; every text is stripped of surrounding whitespace.
        :param guid
            id for the example
        :param texts
            the texts for the example
        :param label
            the label for the example
        """
        stripped_texts = [t.strip() for t in texts]
        self.guid = guid
        self.label = label
        self.texts = stripped_texts

    def get_texts(self):
        """Returns the stripped texts of this example."""
        return self.texts

    def get_label(self):
        """Returns the label of this example."""
        return self.label
|
62 |
+
|
63 |
+
|
64 |
+
class LoggingHandler(logging.Handler):
    """Logging handler that writes formatted records through ``tqdm.write``."""

    def __init__(self, level=logging.NOTSET):
        super().__init__(level)

    def emit(self, record):
        """Format ``record`` and write it out; failures go to ``handleError``."""
        try:
            msg = self.format(record)
            tqdm.write(msg)
            self.flush()
        except Exception:
            # KeyboardInterrupt and SystemExit are not Exception subclasses,
            # so they propagate without an explicit re-raise clause.
            self.handleError(record)
|
77 |
+
|
78 |
+
|
79 |
+
def get_examples(filename, max_examples=0):
    """Read sentence-pair examples with NLI labels from a JSON-lines file.

    Each line must hold the keys ``s1``, ``s2`` and ``label``. The label is
    encoded as 0 for 'entailment', 1 for 'contradiction' and 2 otherwise.

    :param filename: path of the JSON-lines file
    :param max_examples: stop after this many examples; 0 reads everything
    :return: list of InputExample
    """
    label_ids = {'entailment': 0, 'contradiction': 1}
    examples = []
    with open(filename, encoding='utf8') as file:
        for index, raw_line in enumerate(file):
            sample = json.loads(raw_line.rstrip('\n'))
            example = InputExample(guid="%s-%d" % (filename, index),
                                   texts=[sample['s1'], sample['s2']],
                                   label=label_ids.get(sample['label'], 2))
            examples.append(example)
            if 0 < max_examples <= len(examples):
                break
    return examples
|
101 |
+
|
102 |
+
|
103 |
+
def get_qa_examples(filename, max_examples=0, dev=False, positive_oversampling=13):
    """Read question/answer relevance examples from a JSON-lines file.

    Each line must hold the keys ``question``, ``answer`` and ``relevant``.
    Unless ``dev`` is set, every example whose label equals 1 is duplicated
    ``positive_oversampling`` additional times.

    :param filename: path of the JSON-lines file
    :param max_examples: stop once at least this many examples (duplicates
        included) were collected; 0 reads everything
    :param dev: when True, no oversampling is applied
    :param positive_oversampling: number of extra copies per label-1 example
    :return: list of InputExample
    """
    examples = []
    with open(filename, encoding='utf8') as file:
        for index, line in enumerate(file):
            sample = json.loads(line.rstrip('\n'))
            label = sample['relevant']
            guid = "%s-%d" % (filename, index)
            texts = [sample['question'], sample['answer']]
            # One copy always; the extra copies share the same guid.
            copies = 1
            if not dev and label == 1:
                copies += positive_oversampling
            for _ in range(copies):
                examples.append(InputExample(guid=guid, texts=texts, label=label))
            if 0 < max_examples <= len(examples):
                break
    return examples
|
125 |
+
|
126 |
+
|
127 |
+
def map_label(label):
    """Convert a relevance label string to its integer id.

    The label is stripped and lower-cased first; "relevant" maps to 0 and
    "irrelevant" to 1. Any other value raises ``KeyError``.
    """
    normalized = label.strip().lower()
    return {"relevant": 0, "irrelevant": 1}[normalized]
|
130 |
+
|
131 |
+
|
132 |
+
def get_qar_examples(filename, max_examples=0):
    """Read question/answer pairs from a JSON-lines file as positive examples.

    Each line must hold the keys ``question`` and ``answer``; every example
    gets the label 1.0.

    :param filename: path of the JSON-lines file
    :param max_examples: stop after this many examples; 0 reads everything
    :return: list of InputExample
    """
    examples = []
    with open(filename, encoding='utf8') as file:
        for index, raw_line in enumerate(file):
            sample = json.loads(raw_line.rstrip('\n'))
            pair = [sample['question'], sample['answer']]
            examples.append(InputExample(guid="%s-%d" % (filename, index),
                                         texts=pair,
                                         label=1.0))
            if 0 < max_examples <= len(examples):
                break
    return examples
|
147 |
+
|
148 |
+
|
149 |
+
def get_qar_artificial_examples():
    """Pair artificial queries with the passages they belong to.

    Reads the passages from ``data/msmarco/collection.tsv`` (second
    tab-separated column) and the queries from
    ``data/qar/qar_artificial_queries.csv``, where line ``i`` holds the
    ``|``-separated queries for passage ``i``.

    :return: list of InputExample, one per query, all with the label 1.0
    """
    examples = []

    print('Loading passages...')

    # The context manager closes the collection file once it has been read.
    with open('data/msmarco/collection.tsv', 'r', encoding='utf8') as file:
        passages = [line.rstrip('\n').split('\t')[1] for line in file]

    print('Loaded passages')

    with open('data/qar/qar_artificial_queries.csv') as f:
        for i, line in enumerate(f):
            for query in line.rstrip('\n').split('|'):
                # The running example count serves as the numeric part of the guid.
                guid = "%s-%d" % ('', len(examples))
                examples.append(InputExample(guid=guid,
                                             texts=[query, passages[i]],
                                             label=1.0))
    return examples
|
176 |
+
|
177 |
+
|
178 |
+
def get_single_examples(filename, max_examples=0):
    """Read single-text examples from a JSON-lines file.

    Each line must hold the key ``text``; every example gets the label 1.

    :param filename: path of the JSON-lines file
    :param max_examples: stop after this many examples; 0 reads everything
    :return: list of InputExample
    """
    examples = []
    with open(filename, encoding='utf8') as file:
        for index, raw_line in enumerate(file):
            sample = json.loads(raw_line.rstrip('\n'))
            example = InputExample(guid="%s-%d" % (filename, index),
                                   texts=[sample['text']],
                                   label=1)
            examples.append(example)
            if 0 < max_examples <= len(examples):
                break
    return examples
|
193 |
+
|
194 |
+
|
195 |
+
def get_qnli_examples(filename, max_examples=0, no_contradictions=False, fever_only=False):
    """Read statement/evidence pairs from a JSON-lines file as positive examples.

    Each line must hold the keys ``label``, ``statement`` and ``evidence``
    (and ``source`` when ``fever_only`` is set). Samples with empty evidence
    are skipped. Every returned example gets the label 1.0.

    :param filename: path of the JSON-lines file
    :param max_examples: stop after this many examples; 0 reads everything
    :param no_contradictions: skip samples labelled 'contradiction'
    :param fever_only: keep only samples whose source is 'fever'
    :return: list of InputExample
    """
    examples = []
    with open(filename, encoding='utf8') as file:
        for raw_line in file:
            sample = json.loads(raw_line.rstrip('\n'))
            is_contradiction = sample['label'] == 'contradiction'
            if is_contradiction and no_contradictions:
                continue
            if sample['evidence'] == '':
                continue
            if fever_only and sample['source'] != 'fever':
                continue

            # Only kept samples are numbered, so the id equals the count so far.
            guid = "%s-%d" % (filename, len(examples))
            pair = [sample['statement'].strip(), sample['evidence'].strip()]
            examples.append(InputExample(guid=guid, texts=pair, label=1.0))
            if 0 < max_examples <= len(examples):
                break
    return examples
|
218 |
+
|
219 |
+
|
220 |
+
def get_retrieval_examples(filename, negative_corpus='data/msmarco/collection.tsv', max_examples=0, no_statements=True,
                           encoder_model=None, negative_samples_num=4):
    """Build query/passage retrieval examples, optionally with mined negatives.

    Each line of ``filename`` is a JSON object with ``type`` and ``source``
    plus either ``question``/``answer`` or ``statement``/``evidence``.
    Samples with empty evidence are skipped.

    :param filename: path of the JSON-lines file with the positive pairs
    :param negative_corpus: TSV file whose second column supplies additional
        negative passages; only read when ``encoder_model`` is given
    :param max_examples: stop reading once this many pairs were collected;
        0 reads everything
    :param no_statements: skip samples of type 'statement'
    :param encoder_model: model handed to DenseRetriever for negative mining;
        when None no negatives are added
    :param negative_samples_num: maximum number of negatives per example
    :return: list of InputExample with texts ``[query, passage, *negatives]``
        and the label 1.0
    """
    examples = []
    guids = []
    queries = []
    passages = []
    negative_passages = []
    sample_count = 0
    with open(filename, encoding='utf8') as file:
        for line in file:
            sample = json.loads(line.rstrip('\n'))

            if 'evidence' in sample and sample['evidence'] == '':
                continue

            # Numbered before the statement filter, so skipped statements
            # still consume an id.
            guid = "%s-%d" % (filename, sample_count)
            sample_count += 1

            if sample['type'] == 'question':
                query = sample['question']
                passage = sample['answer']
            else:
                query = sample['statement']
                passage = sample['evidence']

            query = query.strip()
            passage = passage.strip()

            if sample['type'] == 'statement' and no_statements:
                continue

            # Keep each pair's own guid so every example gets a unique id.
            guids.append(guid)
            queries.append(query)
            passages.append(passage)

            if sample['source'] == 'natural-questions':
                negative_passages.append(passage)

            if max_examples == len(passages):
                break

    if encoder_model is not None:
        # Load MSMARCO passages
        logging.info('Loading MSM passages...')
        with open(negative_corpus) as file:
            for line in file:
                p = line.rstrip('\n').split('\t')[1]
                negative_passages.append(p)

        logging.info('Building ANN index...')
        dense_retriever = DenseRetriever(model=encoder_model, batch_size=1024, use_gpu=True)
        dense_retriever.create_index_from_documents(negative_passages)
        results = dense_retriever.search(queries=queries, limit=100, probes=256)
        # A retrieved passage that equals the positive passage is not a negative.
        negative_samples = [
            [negative_passages[p[0]] for p in r if negative_passages[p[0]] != passages[i]][:negative_samples_num]
            for i, r in enumerate(results)
        ]

        for i in range(len(queries)):
            texts = [queries[i], passages[i]] + negative_samples[i]
            examples.append(InputExample(guid=guids[i],
                                         texts=texts,
                                         label=1.0))

    else:
        for guid, query, passage in zip(guids, queries, passages):
            examples.append(InputExample(guid=guid,
                                         texts=[query, passage],
                                         label=1.0))

    return examples
|
293 |
+
|
294 |
+
|
295 |
+
def get_ict_examples(filename, max_examples=0):
    """Read sentence pairs from a JSON-lines file as positive examples.

    Each line must hold the keys ``s1`` and ``s2``; both are stripped and
    every example gets the label 1.0.

    :param filename: path of the JSON-lines file
    :param max_examples: stop after this many examples; 0 reads everything
    :return: list of InputExample
    """
    examples = []
    with open(filename, encoding='utf8') as file:
        for index, raw_line in enumerate(file):
            sample = json.loads(raw_line.rstrip('\n'))
            first = sample['s1'].strip()
            second = sample['s2'].strip()
            examples.append(InputExample(guid="%s-%d" % (filename, index),
                                         texts=[first, second],
                                         label=1.0))
            if 0 < max_examples <= len(examples):
                break
    return examples
|
311 |
+
|
312 |
+
|
313 |
+
def preprocess_fever_dataset(path):
    """Join FEVER claims with their evidence text and write them as JSON lines.

    Reads ``wiki-pages/wiki-001.jsonl`` .. ``wiki-109.jsonl`` below ``path``,
    then converts ``train.jsonl`` and ``dev.jsonl`` into ``facts_train.jsonl``
    and ``facts_dev.jsonl``. Each output line holds ``statement``, ``label``,
    ``evidence`` (the evidence sentences) and ``evidence_long`` (the evidence
    sentences plus surrounding article sentences, within a word budget).

    :param path: directory of the FEVER dataset
    """
    def replace_symbols(text):
        # Turn the bracket placeholder tokens back into bracket characters.
        return text.replace('-LRB-', '(') \
            .replace('-RRB-', ')') \
            .replace('-LSB-', '[') \
            .replace('-RSB-', ']'). \
            replace('-LCB-', '{'). \
            replace('-RCB-', '}')

    print('Loading wiki articles...')
    articles = {}
    for i in range(109):
        wid = str(i + 1).zfill(3)
        # Context manager: each of the 109 dump files is closed after reading.
        with open(path + '/wiki-pages/wiki-' + wid + '.jsonl', 'r', encoding='utf8') as file:
            for line in file:
                data = json.loads(line.rstrip('\n'))
                plines = []
                for article_line in data['lines'].split('\n'):
                    slines = article_line.split('\t')
                    # Lines without a numeric first column become empty
                    # entries, which keeps list positions stable.
                    if len(slines) > 1 and slines[0].isnumeric():
                        plines.append(slines[1])
                    else:
                        plines.append('')

                data['id'] = unidecode(data['id'])
                articles[data['id']] = plines

    print('Preprocessing dataset...')
    files = ['train', 'dev']
    for file in files:
        with open(path + '/facts_{}.jsonl'.format(file), 'w+', encoding='utf8') as fo, \
                open(path + '/{}.jsonl'.format(file), encoding='utf8') as f:
            for line in f:
                data = json.loads(line.rstrip('\n'))
                claim = data['claim']
                label = data['label']
                evidence = []
                sents = []
                if label != 'NOT ENOUGH INFO':
                    for evidence_list in data['evidence']:
                        evidence_sents = []
                        extra_left = []
                        extra_right = []
                        evidence_words = 0
                        for e in evidence_list:
                            e[2] = unidecode(e[2])
                            if e[2] in articles and len(articles[e[2]]) > e[3]:
                                eid = e[3]
                                article = articles[e[2]]
                                evidence_sents.append(replace_symbols(article[eid].replace(' ', ' ').strip()))
                                evidence_words += len(article[eid].split(' '))
                                left = []
                                for i in range(0, eid):
                                    left.append(replace_symbols(article[i]))
                                extra_left.append(left)
                                right = []
                                for i in range(eid + 1, len(article)):
                                    right.append(replace_symbols(article[i]))
                                extra_right.append(right)
                            else:
                                # One unresolvable reference discards the whole set.
                                evidence_sents = []
                                break
                        evidence_text = []
                        for i in range(len(evidence_sents)):
                            # Context sentences are added only while the running
                            # word count stays below 254.
                            for j in range(len(extra_left[i])):
                                xlen = len(extra_left[i][j].split(' '))
                                if evidence_words + xlen < 254:
                                    evidence_text.append(extra_left[i][j])
                                    evidence_words += xlen
                            evidence_text.append(evidence_sents[i])
                            for j in range(len(extra_right[i])):
                                xlen = len(extra_right[i][j].split(' '))
                                if evidence_words + xlen < 254:
                                    evidence_text.append(extra_right[i][j])
                                    evidence_words += xlen

                        if len(evidence_text) > 0:
                            evidence_text = unidecode(' '.join(evidence_text))
                            evidence_text = ' '.join(evidence_text.split()).strip()
                            evidence.append(evidence_text)
                            sents.append(' '.join(evidence_sents))

                if label == 'NOT ENOUGH INFO' or len(evidence) > 0:
                    fo.write(json.dumps({
                        'statement': unidecode(replace_symbols(claim)),
                        'label': label,
                        'evidence_long': list(set(evidence)),
                        'evidence': list(set(sents))
                    }) + '\n')
|
405 |
+
|
406 |
+
|
407 |
+
def preprocess_squad_dataset(path, output):
    """Flatten a SQuAD-style JSON file into JSON-lines question/answer records.

    For every question the shortest answer text is kept (the first one on
    ties).  Questions with no answers, with an empty shortest answer, or
    whose answers are all 1000+ characters long are skipped.

    :param path: path of the SQuAD-style JSON file to read
    :param output: path of the JSON-lines file to write; each line holds the
        keys ``question``, ``answer`` and ``evidence`` (the paragraph context)
    """
    with open(path) as src:
        squad = json.load(src)

    with open(output, 'w+') as dst:
        for entry in squad['data']:
            for para in entry['paragraphs']:
                passage = para['context']
                for item in para['qas']:
                    query = item['question']
                    # Only answers shorter than 1000 characters are eligible.
                    candidates = [ans['text'] for ans in item['answers']]
                    candidates = [text for text in candidates if len(text) < 1000]
                    # min() returns the first of several equally short texts.
                    shortest = min(candidates, key=len, default='')
                    if shortest == '':
                        continue
                    record = {'question': query, 'answer': shortest, 'evidence': passage}
                    dst.write(json.dumps(record))
                    dst.write('\n')
|
430 |
+
|
431 |
+
|
432 |
+
def preprocess_nq_dataset():
    """Convert Natural Questions JSON-lines into question/answer/evidence records.

    Reads ``../data/qa/nq/<split>.jsonl`` and writes
    ``../data/qa/nq/<split>_all_h.jsonl`` for every split in ``files``.
    When a short answer exists, ``*`` markers are inserted around it in the
    token list before the long answer is sliced out.  Examples whose long
    answer yields no text are skipped.
    """
    def _clean_token(token):
        # Replace the matched character inside a token so that the later
        # split(' ') keeps one entry per token.
        return re.sub(u" ", "_", token["token"])

    def get_text_blocks(html):
        # The fragment is wrapped in one <div> before being fed to the parser.
        w = WebParser()
        w.feed('<div>{}</div>'.format(html))
        return w.get_blocks()

    files = ['train']
    for file in files:
        # Both handles live in one `with` so the output is flushed and closed
        # even when parsing fails part-way through.
        with open('../data/qa/nq/{}_all_h.jsonl'.format(file), 'w+') as of, \
                open('../data/qa/nq/{}.jsonl'.format(file), encoding='utf8') as f:
            for line in tqdm(f):
                data = json.loads(line.rstrip('\n'))
                question = data['question_text']

                if file == 'dev':
                    # dev records carry tokens instead of a ready-made text field
                    data['document_text'] = " ".join([_clean_token(t) for t in data["document_tokens"]])
                # Drop pre-existing '*' so the only ones left are our markers.
                doc_tokens = data['document_text'].replace('*', '').split(' ')

                annotation = data['annotations'][0]

                short_answer = ''
                if len(annotation['short_answers']) > 0:
                    short_answer_info = annotation['short_answers'][0]
                    start = short_answer_info['start_token']
                    end = short_answer_info['end_token']
                    doc_tokens.insert(start, '*')
                    doc_tokens.insert(end + 1, '*')
                    # Tokens strictly between the two inserted markers.
                    short_answer = ' '.join(doc_tokens[start + 1:end + 1])
                    short_answer = ' '.join(get_text_blocks(short_answer))

                # NOTE(review): the long-answer offsets are applied after the
                # two markers were inserted, so the slice may be shifted --
                # confirm this is intended.
                long_answer_info = annotation['long_answer']
                long_answer = ' '.join(doc_tokens[long_answer_info['start_token']:long_answer_info['end_token']])
                long_answer = ' '.join(get_text_blocks(long_answer))

                if long_answer == '':
                    continue

                example = {
                    'question': question,
                    'answer': short_answer,
                    'evidence': long_answer
                }

                of.write(json.dumps(example) + '\n')
|
478 |
+
|
479 |
+
|
480 |
+
def preprocess_msmarco():
    """Convert MS MARCO JSON-lines into question/answer/evidence records.

    For each split in ``files`` reads ``../data/qa/msmarco/<split>.jsonl`` and
    writes ``../data/qa/msmarco/<split>_all.jsonl``.  The evidence is the text
    of the first passage with ``is_selected == 1`` (empty string when none);
    queries containing ``_`` are skipped.
    """
    files = ['train', 'dev']
    for file in files:
        # The output is managed by `with` so it is always flushed and closed.
        with open('../data/qa/msmarco/{}_all.jsonl'.format(file), 'w+') as fo, \
                open('../data/qa/msmarco/{}.jsonl'.format(file), 'r') as f:
            for line in f:
                data = json.loads(line.rstrip('\n'))
                q = data['query']
                a = data['answers'][0]
                # First selected passage wins; '' when nothing is selected.
                evidence = next(
                    (p['passage_text'] for p in data['passages'] if p['is_selected'] == 1),
                    ''
                )

                if '_' in q:
                    continue

                example = {
                    'question': q,
                    'answer': a,
                    'evidence': evidence
                }

                fo.write(json.dumps(example) + '\n')
|
505 |
+
|
506 |
+
|
507 |
+
def create_qa_ranking_dataset():
    """Build a passage-ranking dataset from MS MARCO.

    For each split reads ``../data/qa/msmarco/<split>.jsonl`` and writes
    ``../data/qa_ranking_large/<split>.jsonl`` with one record per kept
    passage (keys ``question``, ``answer``, ``relevant``).  Every selected
    passage is kept; non-selected passages are kept only until
    ``max_irrelevant`` of them were written for the current query (10 for
    ``train``, 1 otherwise).  The per-split totals are printed at the end.
    """
    files = ['train', 'dev']
    for file in files:
        relevant = 0
        irrelevant = 0
        max_irrelevant = 10 if file == 'train' else 1
        # The output is managed by `with` so it is always flushed and closed.
        with open('../data/qa_ranking_large/{}.jsonl'.format(file), 'w+') as fo, \
                open('../data/qa/msmarco/{}.jsonl'.format(file), 'r') as f:
            for line in f:
                data = json.loads(line.rstrip('\n'))
                added_irrelevant = 0  # negatives written for this query so far
                for passage in data['passages']:
                    is_relevant = passage['is_selected'] == 1
                    if not is_relevant and added_irrelevant >= max_irrelevant:
                        continue
                    example = {
                        'question': data['query'],
                        'answer': passage['passage_text'],
                        'relevant': passage['is_selected']
                    }
                    fo.write(json.dumps(example) + '\n')
                    if is_relevant:
                        relevant += 1
                    else:
                        added_irrelevant += 1
                        irrelevant += 1

        print(relevant, irrelevant)
|
533 |
+
|
534 |
+
|
535 |
+
def preprocess_qa_nli_dataset(path):
    """Rewrite ``dev.tsv`` and ``train.tsv`` under *path* as two-column files.

    Each input line is split on tabs.  Column 2 is treated as a question: a
    trailing ``.`` is replaced by ``?`` and `` ?`` is appended when it still
    does not end with ``?``.  The output line is
    ``<question> <column 3>\\t<column 4>`` and goes to ``<split>_.tsv``.

    :param path: directory containing ``dev.tsv`` and ``train.tsv``
    """
    for split in ('dev', 'train'):
        prefix = path + '/' + split
        with open(prefix + '_.tsv', 'w+') as sink, open(prefix + '.tsv', 'r') as source:
            for row in source:
                cols = row.rstrip().split('\t')
                query = cols[2]
                if query.endswith('.'):
                    query = query[:-1] + '?'
                if not query.endswith('?'):
                    query += ' ?'
                sink.write('{} {}\t{}\n'.format(query, cols[3], cols[4]))
|
551 |
+
|
552 |
+
|
553 |
+
def get_pair_input(tokenizer, sent1, sent2, max_len=256):
    """Encode a sentence pair as token ids plus segment ids.

    The pair is rendered as ``[CLS] sent1 [SEP] sent2 [SEP]``, then both the
    token list and the id list are truncated to *max_len* entries.

    :param tokenizer: object providing ``tokenize(text)`` and ``encode(text)``
    :param sent1: first sentence
    :param sent2: second sentence
    :param max_len: maximum number of tokens / ids kept
    :return: ``(token_ids, segment_ids)``; segment ids are 0 up to and
        including the first ``[SEP]`` token and 1 afterwards
    """
    text = "[CLS] {} [SEP] {} [SEP]".format(sent1, sent2)

    # NOTE(review): tokens and ids come from two separate tokenizer calls;
    # whether they have the same length depends on the tokenizer -- confirm.
    tokens = tokenizer.tokenize(text)[:max_len]
    token_ids = tokenizer.encode(text)[:max_len]

    segment_ids = []
    current_segment = 0
    for token in tokens:
        segment_ids.append(current_segment)
        # The first separator still belongs to segment 0; everything after
        # it (later separators included) belongs to segment 1.
        if token == '[SEP]':
            current_segment = 1
    return token_ids, segment_ids
|
570 |
+
|
571 |
+
|
572 |
+
def build_batch(tokenizer, text_list, max_len=256):
    """Encode sentence pairs and pad them to a common length.

    :param tokenizer: tokenizer handed to ``get_pair_input``
    :param text_list: iterable of ``(sent1, sent2)`` pairs
    :param max_len: maximum sequence length per pair
    :return: ``(token_id_list, segment_list, attention_masks)``, three lists
        of equal-length rows, or ``(None, None, None)`` when no pair could
        be encoded
    """
    token_id_list = []
    segment_list = []
    attention_masks = []
    longest = -1

    for sent1, sent2 in text_list:
        ids, segs = get_pair_input(tokenizer, sent1, sent2, max_len=max_len)
        if ids is None or segs is None:
            continue
        token_id_list.append(ids)
        segment_list.append(segs)
        attention_masks.append([1] * len(ids))
        longest = max(longest, len(ids))

    if not token_id_list:
        return None, None, None

    # padding
    assert (len(token_id_list) == len(segment_list))
    for ids, mask, segs in zip(token_id_list, attention_masks, segment_list):
        ids += [0] * (longest - len(ids))
        # Padding positions must be masked out (0).  They were previously
        # padded with 1, which made the model attend to pad tokens and made
        # a pair's result depend on the other pairs in its batch.
        mask += [0] * (longest - len(mask))
        # Segment rows are extended with 1 up to the common length; this also
        # covers rows that came back shorter than their id row.
        segs += [1] * (longest - len(segs))

    return token_id_list, segment_list, attention_masks
|
600 |
+
|
601 |
+
|
602 |
+
"""
|
603 |
+
class QAEncoderBatchGenerator(Sequence):
|
604 |
+
def __init__(self, ids_to_paragraphs: List, ids_to_queries: Dict, qrels_train: List, tokenizer,
|
605 |
+
max_question_length=64, max_answer_length=512, batch_size=64):
|
606 |
+
np.random.seed(42)
|
607 |
+
self.ids_to_paragraphs = ids_to_paragraphs
|
608 |
+
self.ids_to_queries = ids_to_queries
|
609 |
+
self.qrels_train = qrels_train
|
610 |
+
self.tokenizer = tokenizer
|
611 |
+
self.max_question_length = max_question_length
|
612 |
+
self.max_answer_length = max_answer_length
|
613 |
+
self.batch_size = batch_size
|
614 |
+
self.on_epoch_end()
|
615 |
+
|
616 |
+
def __len__(self):
|
617 |
+
return int(np.floor(len(self.qrels_train) / self.batch_size))
|
618 |
+
|
619 |
+
def __getitem__(self, index):
|
620 |
+
indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
|
621 |
+
|
622 |
+
question_texts = []
|
623 |
+
answer_texts = []
|
624 |
+
|
625 |
+
for index in indexes:
|
626 |
+
query_id = self.qrels_train[index][0]
|
627 |
+
answer_id = self.qrels_train[index][1]
|
628 |
+
question_texts.append(self.ids_to_queries[query_id])
|
629 |
+
answer_texts.append(self.ids_to_paragraphs[answer_id])
|
630 |
+
|
631 |
+
question_inputs = get_inputs_text(texts=question_texts, tokenizer=self.tokenizer,
|
632 |
+
max_length=self.max_question_length)
|
633 |
+
answer_inputs = get_inputs_text(texts=answer_texts, tokenizer=self.tokenizer,
|
634 |
+
max_length=self.max_answer_length)
|
635 |
+
|
636 |
+
x = {
|
637 |
+
'question_input_word_ids': question_inputs[0],
|
638 |
+
'question_input_masks': question_inputs[1],
|
639 |
+
'question_input_segments': question_inputs[2],
|
640 |
+
'answer_input_word_ids': answer_inputs[0],
|
641 |
+
'answer_input_masks': answer_inputs[1],
|
642 |
+
'answer_input_segments': answer_inputs[2]
|
643 |
+
}
|
644 |
+
y = np.zeros((self.batch_size, self.batch_size))
|
645 |
+
|
646 |
+
return x, y
|
647 |
+
|
648 |
+
def on_epoch_end(self, logs=None):
|
649 |
+
self.indexes = np.arange(len(self.qrels_train))
|
650 |
+
np.random.shuffle(self.indexes)
|
651 |
+
"""
|
652 |
+
|
653 |
+
|
654 |
+
def create_qa_encoder_pretraining_dataset(dataset_file, output_file, total_samples):
    """Sample pseudo question/answer pairs from a paragraph corpus.

    Each line of *dataset_file* is a JSON object whose ``paragraphs`` value is
    a list of paragraphs, each a list of sentences.  Samples alternate
    between two schemes:

    * even samples (Body-First-Selection): a random sentence of the first
      paragraph is the question, a random later paragraph is the answer;
    * odd samples (Inverse-Cloze-Task): a random sentence of a random
      paragraph is the question and the paragraph is the answer, with the
      question sentence removed 90% of the time.

    :param dataset_file: JSON-lines corpus to read (UTF-8)
    :param output_file: JSON-lines file to write, keys ``s1`` and ``s2``
    :param total_samples: number of pairs to write
    :raises ValueError: if samples are requested but no article has at least
        two paragraphs (previously this looped forever)
    """
    print('Loading paragraphs in memory...')

    # `with` closes the corpus file, which used to stay open.
    with open(dataset_file, 'r', encoding='utf8') as file:
        articles = [json.loads(line.rstrip('\n'))['paragraphs'] for line in file]

    print('Loaded paragraphs in memory')

    # Articles with fewer than two paragraphs are rejected below; without a
    # single eligible article the sampling loop could never terminate.
    if total_samples > 0 and not any(len(article) >= 2 for article in articles):
        raise ValueError('no article with at least two paragraphs in {}'.format(dataset_file))

    with open(output_file, 'w+') as o:
        added = 0
        while added < total_samples:
            index = randint(0, len(articles) - 1)
            article = articles[index]
            if len(article) < 2:
                continue
            if added % 2 == 0:  # Body-First-Selection
                paragraph = article[0]
                random_question_sentence_id = randint(0, len(paragraph) - 1)
                question = paragraph[random_question_sentence_id]
                random_answer_paragraph_id = randint(1, len(article) - 1)
                answer = ' '.join(article[random_answer_paragraph_id])
            else:  # Inverse-Cloze-Task
                random_paragraph_id = randint(0, len(article) - 1)
                paragraph = article[random_paragraph_id]
                random_question_sentence_id = randint(0, len(paragraph) - 1)
                question = paragraph[random_question_sentence_id]
                choice = random()
                if choice < 0.9:
                    # Usual case: hide the question sentence from the answer.
                    answer = ' '.join([paragraph[j] for j in range(len(paragraph)) if j != random_question_sentence_id])
                else:
                    answer = ' '.join(paragraph)

            o.write(json.dumps({'s1': question, 's2': answer}) + '\n')

            added += 1
            if added % 10000 == 0:
                print(added)
|
701 |
+
|
702 |
+
|
703 |
+
def create_random_wiki_paragraphs(dataset_file, output_file, total_samples, max_length=510):
    """Sample random paragraphs, each capped below *max_length* words.

    Each line of *dataset_file* is a JSON object whose ``paragraphs`` value is
    a list of paragraphs, each a list of sentences.  For every sample a
    random paragraph of a random article is taken and its leading sentences
    are kept for as long as the running word count stays below *max_length*.

    :param dataset_file: JSON-lines corpus to read (UTF-8)
    :param output_file: JSON-lines file to write, one ``{"text": ...}`` per line
    :param total_samples: number of paragraphs to write
    :param max_length: exclusive upper bound on the words per written text
    :raises ValueError: if samples are requested but every article is empty
        (previously this looped forever)
    """
    print('Loading paragraphs in memory...')

    # `with` closes the corpus file, which used to stay open.
    with open(dataset_file, 'r', encoding='utf8') as file:
        articles = [json.loads(line.rstrip('\n'))['paragraphs'] for line in file]

    print('Loaded paragraphs in memory')

    # Empty articles are rejected below; without a single non-empty article
    # the sampling loop could never terminate.
    if total_samples > 0 and not any(len(article) >= 1 for article in articles):
        raise ValueError('no non-empty article in {}'.format(dataset_file))

    with open(output_file, 'w+') as o:
        added = 0
        while added < total_samples:
            index = randint(0, len(articles) - 1)
            article = articles[index]
            if len(article) < 1:
                continue
            random_paragraph_id = randint(0, len(article) - 1)
            paragraph = article[random_paragraph_id]

            p = []
            words = 0
            for sentence in paragraph:
                nwords = len(sentence.split(' '))
                if nwords + words >= max_length:
                    break
                p.append(sentence)
                # The running total was never updated before, so the limit
                # only applied to single sentences, not to the whole text.
                words += nwords

            o.write(json.dumps({'text': ' '.join(p)}) + '\n')
            added += 1
            if added % 10000 == 0:
                print(added)
|
743 |
+
|
744 |
+
|
745 |
+
def load_unsupervised_dataset(dataset_file):
    """Load a pickled dataset and report the length of its first element.

    :param dataset_file: path of the pickle file
    :return: ``(x, len(x[0]))`` where ``x`` is the unpickled object
    """
    print('Loading dataset...')
    # SECURITY: pickle can execute arbitrary code while loading; only use
    # this on files produced by trusted code.
    with open(dataset_file, "rb") as f:  # closed deterministically
        x = pickle.load(f)
    print('Done')
    return x, len(x[0])
|
750 |
+
|
751 |
+
|
752 |
+
def create_entailment_dataset(dataset_file, tokenizer, output_file):
    """Tokenize an entailment JSON-lines file and pickle it as ``[x, y]``.

    Every input line is a JSON object with keys ``s1``, ``s2`` and ``label``.
    ``x`` is a list of six arrays (the concatenated ``tokenize`` outputs of
    ``s1`` and ``s2``, one array per position); ``y`` holds one-hot labels in
    the order entailment, contradiction, anything else.

    :param dataset_file: JSON-lines file to read (UTF-8)
    :param tokenizer: tokenizer forwarded to ``tokenize``
    :param output_file: path of the pickle file to write (protocol 4)
    """
    one_hot = {
        'entailment': [1.0, 0.0, 0.0],
        'contradiction': [0.0, 1.0, 0.0],
    }
    other = [0.0, 0.0, 1.0]

    x = [[] for _ in range(6)]
    y = []

    with open(dataset_file, encoding='utf8') as file:
        for j, line in enumerate(file):
            if j % 10000 == 0:
                print(j)
            sample = json.loads(line.rstrip('\n'))
            s1 = tokenize(text=sample['s1'],
                          max_length=256,
                          tokenizer=tokenizer)
            s2 = tokenize(text=sample['s2'],
                          max_length=256,
                          tokenizer=tokenizer)
            # presumably each tokenize() call yields three sequences, giving
            # six per example -- confirm against tokenize()
            example = s1 + s2
            for i in range(6):
                x[i].append(example[i])
            label = one_hot.get(sample['label'], other)
            y.append(np.asarray(label, dtype='float32'))

    for i in range(6):
        x[i] = np.asarray(x[i])

    y = np.asarray(y)
    data = [x, y]

    # `with` guarantees the pickle is flushed and the handle closed; the
    # handle was previously left for the garbage collector.
    with open(output_file, "wb") as out:
        pickle.dump(data, out, protocol=4)
|
787 |
+
|
788 |
+
|
789 |
+
def load_supervised_dataset(dataset_file):
    """Load a pickled ``[x, y]`` dataset.

    :param dataset_file: path of the pickle file
    :return: ``(d[0], d[1])`` of the unpickled object
    """
    print('Loading dataset...')
    # SECURITY: pickle can execute arbitrary code while loading; only use
    # this on files produced by trusted code.
    with open(dataset_file, "rb") as f:  # closed deterministically
        d = pickle.load(f)
    print('Done')
    return d[0], d[1]
|
794 |
+
|
795 |
+
|
796 |
+
def preprocess_sg():
    """Write ``<question> <answer>`` lines for the SQuAD dev split.

    Reads ``../data/qa/squad/dev.jsonl`` and writes
    ``../data/qa/squad/dev_sg_sq.txt``.  A trailing ``.`` on the question is
    replaced by ``?`` and `` ?`` is appended when it still does not end with
    ``?``.
    """
    # Both files are managed by `with`; neither was closed before.
    with open('../data/qa/squad/dev_sg_sq.txt', 'w+') as fo, \
            open('../data/qa/squad/dev.jsonl', 'r') as f:
        for line in f:
            d = json.loads(line.rstrip('\n'))
            q = d['question']
            a = d['answer']

            if q.endswith('.'):
                q = q[:-1] + '?'
            if not q.endswith('?'):
                q += ' ?'

            fo.write(q + ' ' + a + '\n')
|
810 |
+
|
811 |
+
|
812 |
+
def preprocess_msmarco_():
    """Write ``<question> <answer>\\t<well-formed answer>`` lines for MS MARCO dev.

    Reads ``../data/qa/msmarco/dev.jsonl`` and writes
    ``../data/qa/msmarco/dev_sg.tsv``.  Records are skipped when
    ``wellFormedAnswers`` equals the string ``'[]'``, when the query type is
    ``DESCRIPTION``, when the answer is longer than the query, or when the
    query contains ``_``.
    """
    # The output is managed by `with`; it was never closed before.
    with open('../data/qa/msmarco/dev_sg.tsv', 'w+') as fo, \
            open('../data/qa/msmarco/dev.jsonl', 'r') as f:
        for line in f:
            data = json.loads(line.rstrip('\n'))
            if data['wellFormedAnswers'] == '[]' or data['query_type'] == 'DESCRIPTION':
                continue
            q = data['query']
            a = data['answers'][0]
            fa = data['wellFormedAnswers'][0]

            # Compared against the query as read, before '?' is appended.
            if len(a) > len(q):
                continue

            if q.endswith('.'):
                q = q[:-1] + '?'
            if not q.endswith('?'):
                q += ' ?'

            if '_' in q:
                continue

            fo.write('{} {}\t{}\n'.format(q, a, fa))
|
835 |
+
|
836 |
+
|
837 |
+
def create_qnli_nq():
    """Join generated statements with Natural Questions records.

    For each split, line ``i`` of ``../data/qnli/<split>.txt`` is paired with
    line ``i`` of ``../data/qa/nq/<split>_nli.jsonl`` and written to
    ``../data/qnli/<split>_nq.jsonl`` with keys ``statement``, ``question``,
    ``answer`` and ``evidence``.

    :raises IndexError: if there are fewer statements than records
    """
    files = ['train', 'dev']
    for file in files:
        # All three files are managed by `with`; none was closed before.
        with open('../data/qnli/{}_nq.jsonl'.format(file), 'w+') as o:
            with open('../data/qnli/{}.txt'.format(file)) as fs:
                statements = [s.rstrip('\n') for s in fs]

            with open('../data/qa/nq/{}_nli.jsonl'.format(file)) as fd:
                for i, line in enumerate(fd):
                    d = json.loads(line.rstrip('\n'))
                    # Indexing (not zip) keeps a misaligned statements file
                    # a hard error instead of silently truncating.
                    o.write(json.dumps({
                        'statement': statements[i],
                        'question': d['question'],
                        'answer': d['answer'],
                        'evidence': d['evidence']
                    }) + '\n')
|
854 |
+
|
855 |
+
|
856 |
+
def create_qnli_ms():
    """Build statement/question/answer/evidence records from MS MARCO.

    For each split reads ``../data/qa/msmarco/<split>.jsonl`` and writes
    ``../data/qnli/<split>_ms.jsonl``.  The first well-formed answer is the
    statement.  Records are skipped when ``wellFormedAnswers`` equals the
    string ``'[]'``, when the query type is ``DESCRIPTION``, when no passage
    is selected, or when the query contains ``_``.
    """
    files = ['train', 'dev']
    for file in files:
        # The output is managed by `with`; it was never closed before.
        with open('../data/qnli/{}_ms.jsonl'.format(file), 'w+') as o, \
                open('../data/qa/msmarco/{}.jsonl'.format(file), 'r') as f:
            for line in f:
                data = json.loads(line.rstrip('\n'))
                if data['wellFormedAnswers'] == '[]' or data['query_type'] == 'DESCRIPTION':
                    continue
                q = data['query']
                a = data['answers'][0]
                fa = data['wellFormedAnswers'][0]

                # No early exit: when several passages are selected the
                # last one is used.
                evidence = ''
                for p in data['passages']:
                    if p['is_selected'] == 1:
                        evidence = p['passage_text']
                if evidence == '' or '_' in q:
                    continue

                o.write(json.dumps({
                    'statement': fa,
                    'question': q,
                    'answer': a,
                    'evidence': evidence
                }) + '\n')
|
882 |
+
|
883 |
+
|
884 |
+
def create_qnli_ms_2():
    """Join generated statements with MS MARCO records lacking well-formed answers.

    For each split the statements in
    ``../data/qa/msmarco/statements_<split>_ms.txt`` are consumed in order
    and paired with the records of ``../data/qa/msmarco/<split>.jsonl`` that
    pass the filters; results go to ``../data/qnli/<split>_ms_nw.jsonl``.

    Only records whose ``wellFormedAnswers`` equals the string ``'[]'`` and
    whose query type is not ``DESCRIPTION`` are used; those with the answer
    ``'No Answer Present.'`` or a ``_`` in the query are skipped.  A record
    with a newline in its query or answer is not written, but the statement
    counter still advances by its number of lines.
    """
    files = ['train', 'dev']
    for file in files:
        # Output and statements files are managed by `with`; neither was
        # closed before.
        with open('../data/qnli/{}_ms_nw.jsonl'.format(file), 'w+') as o:
            with open('../data/qa/msmarco/statements_{}_ms.txt'.format(file)) as fs:
                statements = [s.rstrip('\n') for s in fs]
            with open('../data/qa/msmarco/{}.jsonl'.format(file), 'r') as f:
                n = 0  # index of the next unused statement
                for line in f:
                    data = json.loads(line.rstrip('\n'))
                    if data['wellFormedAnswers'] != '[]' or data['query_type'] == 'DESCRIPTION':
                        continue
                    q = data['query']
                    a = data['answers'][0]

                    if a == 'No Answer Present.' or '_' in q:
                        continue

                    # No early exit: the last selected passage is used.
                    evidence = ''
                    for p in data['passages']:
                        if p['is_selected'] == 1:
                            evidence = p['passage_text']

                    if '\n' in q or '\n' in a:
                        # presumably such a record spans several lines in the
                        # statements file, so that many entries are skipped --
                        # confirm against the code that produced the file
                        nl = (q + a).count('\n') + 1
                        print(nl)
                        n += nl
                        continue

                    o.write(json.dumps({
                        'statement': statements[n],
                        'question': q,
                        'answer': a,
                        'evidence': evidence
                    }) + '\n')

                    n += 1
|
921 |
+
|
922 |
+
|
923 |
+
def create_msmarco_sg():
    """Write ``<question> <answer>`` lines for MS MARCO records lacking well-formed answers.

    For each split reads ``../data/qa/msmarco/<split>.jsonl`` and writes
    ``../data/qa/msmarco/<split>_ms_sg.txt``.  Only records whose
    ``wellFormedAnswers`` equals the string ``'[]'`` and whose query type is
    not ``DESCRIPTION`` are used; those with the answer
    ``'No Answer Present.'`` or a ``_`` in the query are skipped.  A trailing
    ``.`` on the query is replaced by ``?`` and `` ?`` is appended when it
    still does not end with ``?``.
    """
    files = ['train', 'dev']
    for file in files:
        # The output is managed by `with`; it was never closed before.
        with open('../data/qa/msmarco/{}_ms_sg.txt'.format(file), 'w+') as o, \
                open('../data/qa/msmarco/{}.jsonl'.format(file), 'r') as f:
            for line in f:
                data = json.loads(line.rstrip('\n'))
                if data['wellFormedAnswers'] != '[]' or data['query_type'] == 'DESCRIPTION':
                    continue
                q = data['query']
                a = data['answers'][0]

                if a == 'No Answer Present.' or '_' in q:
                    continue

                if q.endswith('.'):
                    q = q[:-1] + '?'
                if not q.endswith('?'):
                    q += ' ?'

                o.write(q + ' ' + a + '\n')
|
944 |
+
"""
|
945 |
+
o.write(json.dumps({
|
946 |
+
'statement': fa,
|
947 |
+
'question': q,
|
948 |
+
'answer': a,
|
949 |
+
'evidence': evidence
|
950 |
+
}) + '\n')
|
951 |
+
"""
|
952 |
+
|
953 |
+
|
954 |
+
def create_qar():
    """Merge Natural Questions and MS MARCO into one question/answer file.

    For each split, the records of ``../data/qa/nq/<split>_all.jsonl`` and
    then ``../data/qa/msmarco/<split>_all.jsonl`` are written to
    ``../data/qar/<split>2.jsonl`` with keys ``question``, ``answer`` (the
    record's evidence) and ``source``.  Records with empty evidence are
    skipped.
    """
    sources = (
        ('../data/qa/nq/{}_all.jsonl', 'natural-questions'),
        ('../data/qa/msmarco/{}_all.jsonl', 'msmarco'),
    )
    files = ['train', 'dev']
    for file in files:
        # Output and inputs are managed by `with`; none was closed before.
        with open('../data/qar/{}2.jsonl'.format(file), 'w+') as o:
            for pattern, source in sources:
                with open(pattern.format(file)) as f:
                    for line in f:
                        d = json.loads(line.rstrip('\n'))
                        if d['evidence'] == '':
                            continue
                        o.write(json.dumps({
                            'question': d['question'],
                            'answer': d['evidence'],
                            'source': source
                        }) + '\n')
|
980 |
+
|
981 |
+
|
982 |
+
"""
|
983 |
+
question_inputs = get_inputs_text(texts=question_texts, tokenizer=tokenizer,
|
984 |
+
max_length=max_question_length)
|
985 |
+
answer_inputs = get_inputs_text(texts=answer_texts, tokenizer=tokenizer,
|
986 |
+
max_length=max_answer_length)
|
987 |
+
|
988 |
+
x = {
|
989 |
+
'question_input_word_ids': question_inputs[0],
|
990 |
+
'question_input_masks': question_inputs[1],
|
991 |
+
'question_input_segments': question_inputs[2],
|
992 |
+
'answer_input_word_ids': answer_inputs[0],
|
993 |
+
'answer_input_masks': answer_inputs[1],
|
994 |
+
'answer_input_segments': answer_inputs[2]
|
995 |
+
}
|
996 |
+
y = np.zeros((self.batch_size, self.batch_size))
|
997 |
+
"""
|
998 |
+
|
999 |
+
"""
|
1000 |
+
create_random_wiki_paragraphs('../data/wikipedia/wiki_paragraphs.jsonl',
|
1001 |
+
'../data/random_wiki_paragraphs.jsonl',
|
1002 |
+
total_samples=1000000,
|
1003 |
+
max_length=510)
|
1004 |
+
"""
|
1005 |
+
|
1006 |
+
# preprocess_nq_dataset()
|
models/dense_retriever.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import h5py
|
2 |
+
import logging
|
3 |
+
import pickle
|
4 |
+
import sqlite3
|
5 |
+
import struct
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from models.vector_index import VectorIndex
|
10 |
+
|
11 |
+
import random
|
12 |
+
|
13 |
+
|
14 |
+
class DenseRetriever:
    """Nearest-neighbour retrieval over encoded vectors stored in SQLite.

    :param model: encoder exposing ``encode(queries, batch_size=...)``
    :param db_path: path of the SQLite database holding the encoded rows
    :param batch_size: batch size passed to ``model.encode``
    :param use_gpu: stored on the instance; not used by this class itself
    :param debug: when true, ``populate_index`` loads only the first 1000 rows
    """

    def __init__(self, model, db_path, batch_size=64, use_gpu=False, debug=False):
        self.model = model
        self.vector_index = VectorIndex(768)
        self.batch_size = batch_size
        self.use_gpu = use_gpu
        self.db = sqlite3.connect(db_path)
        # Rows are accessed by column name in populate_index().
        self.db.row_factory = sqlite3.Row
        self.debug = debug

    def load_pretrained_index(self, path):
        """Load a previously built vector index from *path*."""
        self.vector_index.load(path)

    def populate_index(self, table_name):
        """Add every ``encoded`` vector of *table_name* to the index, ordered by ``idx``.

        :param table_name: table to read; it is interpolated into the SQL
            text, so it must come from trusted code, never from user input
        """
        cur = self.db.cursor()
        query = f'SELECT * FROM {table_name} ORDER BY idx'
        if self.debug:
            query += ' LIMIT 1000'
        for r in cur.execute(query):
            e = r['encoded']
            # The blob is a packed run of native float32 values; decode it in
            # one call instead of one struct.unpack per float.  `count`
            # ignores trailing bytes that do not form a whole value.
            v = np.frombuffer(e, dtype=np.float32, count=len(e) // 4)
            # np.array copies the read-only buffer view into a writable,
            # contiguous (1, dim) row.
            self.vector_index.index.add(np.array(v, dtype=np.float32).reshape(1, -1))
            print(f"\rAdded {self.vector_index.index.ntotal} vectors", end='')
        print()
        logging.info("Finished adding vectors.")

    def search(self, queries, limit=1000, probes=512, min_similarity=0):
        """Return, per query, the ``(id, similarity)`` hits above *min_similarity*.

        :param queries: texts to encode and search for
        :param limit: number of neighbours requested from the index (``k``)
        :param probes: forwarded to ``vector_index.search``
        :param min_similarity: exclusive lower bound on the similarity
        :return: one list of ``(id, similarity)`` tuples per query, in the
            order returned by the index
        """
        query_vectors = self.model.encode(queries, batch_size=self.batch_size)
        ids, similarities = self.vector_index.search(query_vectors, k=limit, probes=probes)
        return [
            [(doc_id, sim) for doc_id, sim in zip(row_ids, row_sims) if sim > min_similarity]
            for row_ids, row_sims in zip(ids, similarities)
        ]
|
models/encoder/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
models/encoder/0_BERT/config.json
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"BertModel"
|
4 |
+
],
|
5 |
+
"attention_probs_dropout_prob": 0.1,
|
6 |
+
"hidden_act": "gelu",
|
7 |
+
"hidden_dropout_prob": 0.1,
|
8 |
+
"hidden_size": 768,
|
9 |
+
"initializer_range": 0.02,
|
10 |
+
"intermediate_size": 3072,
|
11 |
+
"layer_norm_eps": 1e-12,
|
12 |
+
"max_position_embeddings": 512,
|
13 |
+
"model_type": "bert",
|
14 |
+
"num_attention_heads": 12,
|
15 |
+
"num_hidden_layers": 12,
|
16 |
+
"output_past": true,
|
17 |
+
"pad_token_id": 0,
|
18 |
+
"type_vocab_size": 2,
|
19 |
+
"vocab_size": 28996
|
20 |
+
}
|
models/encoder/0_BERT/sentence_bert_config.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"max_seq_length": 256,
|
3 |
+
"do_lower_case": false
|
4 |
+
}
|
models/encoder/0_BERT/special_tokens_map.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
|
models/encoder/0_BERT/tokenizer_config.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"do_lower_case": false, "max_len": 512, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
|
models/encoder/0_BERT/vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
models/encoder/1_Pooling/config.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"word_embedding_dimension": 768,
|
3 |
+
"pooling_mode_cls_token": false,
|
4 |
+
"pooling_mode_mean_tokens": true,
|
5 |
+
"pooling_mode_max_tokens": false,
|
6 |
+
"pooling_mode_mean_sqrt_len_tokens": false
|
7 |
+
}
|
models/encoder/config.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"__version__": "1.0"
|
3 |
+
}
|
models/encoder/modules.json
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"idx": 0,
|
4 |
+
"name": "0",
|
5 |
+
"path": "0_BERT",
|
6 |
+
"type": "sentence_transformers.models.BERT"
|
7 |
+
},
|
8 |
+
{
|
9 |
+
"idx": 1,
|
10 |
+
"name": "1",
|
11 |
+
"path": "1_Pooling",
|
12 |
+
"type": "sentence_transformers.models.Pooling"
|
13 |
+
}
|
14 |
+
]
|
models/encoder/stats.csv
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
0.5779840120767146,0.8015233128623872,0.8583559199917659,0.8996466188630048,0.9569938587161629,0.6784459557674578
|
2 |
+
0.5916046248327443,0.817116684392905,0.8708958040278588,0.9097505746732082,0.9613682368682883,0.6921505611721683
|
3 |
+
0.6190860122825677,0.8418190551343191,0.8930078567262497,0.927230932857584,0.970237074141421,0.7180167815120542
|
4 |
+
0.6270628195011494,0.8501046419871685,0.8993378392287371,0.9313308402236937,0.9730847085463341,0.7257908699942954
|
5 |
+
0.6326380073420935,0.8536212989329948,0.9017051497581227,0.9325831131848904,0.9728273921844444,0.7306655399664306
|
6 |
+
0.6432565958760764,0.8622328198442378,0.9086698459532714,0.9393248018664013,0.9762754314337667,0.740502524041259
|
7 |
+
0.6463272377946272,0.8637252547431983,0.9105739870312554,0.94061138367585,0.9756750265893573,0.7430413705788191
|
8 |
+
0.6501698287988472,0.8689573541016228,0.9141592616735856,0.9429100765087316,0.9773390057295777,0.7468761817476934
|
9 |
+
0.6538580299859333,0.8707757230589769,0.9151542182728926,0.9432531649912512,0.9766699831886644,0.7501280666772407
|
10 |
+
0.6536007136240436,0.8742066078841734,0.9184993309774591,0.9457920197618966,0.9773218513054517,0.751086942266862
|
11 |
+
0.6515593371530518,0.8723539300785672,0.9152399903935224,0.9439908052286685,0.9773561601537036,0.7486908765241613
|
12 |
+
0.6565512745737125,0.871650598689402,0.9158060863896799,0.9442824304388102,0.9770645349435619,0.7520715039306786
|
13 |
+
0.6568600542079802,0.8752358733317323,0.919545750849144,0.9464781967269359,0.9773733145778296,0.7533833000405535
|
14 |
+
0.6647682437300579,0.8794387072425979,0.922376230829931,0.9491371324664631,0.9790716025663019,0.7595162032832385
|
15 |
+
0.6617490650838851,0.8790270010635743,0.919528596425018,0.9472158369643531,0.9780938003911208,0.7577154555045574
|
16 |
+
0.6612172779359797,0.8765739184135588,0.9201290012694274,0.945397468006999,0.9761381960407589,0.756179860041224
|
17 |
+
0.664836861426562,0.8797989501492435,0.9218787525302775,0.9485367276220538,0.9781452636634989,0.7597794971475365
|
18 |
+
0.6645795450646722,0.8795073249391018,0.9216557450166398,0.9480220948982743,0.9778707928774831,0.7592882243172064
|
19 |
+
0.6696229457577109,0.8815658558342196,0.9232682608844821,0.9485367276220538,0.9783511167530106,0.7631352909333269
|
20 |
+
0.6653686485744673,0.8806052080831647,0.9231824887638522,0.9483480289566679,0.9792259923834357,0.7605593087121496
|
21 |
+
0.6703262771468762,0.8829210553401722,0.924383298452671,0.9496689196143685,0.9781967269358768,0.763841387028115
|
22 |
+
0.6659004357223728,0.8810169142621882,0.9229594812502144,0.9479020139293924,0.9770816893676879,0.7606746681736507
|
23 |
+
0.6761587813497101,0.8852025937489278,0.9259100421998834,0.9505437952447936,0.9787456685079082,0.7682947389867827
|
24 |
+
0.6741688681510961,0.8837273132740934,0.9250866298418362,0.9504580231241637,0.9782996534806326,0.7671206271858944
|
25 |
+
0.6687137612790338,0.882183415102755,0.9244862249974268,0.9500120080968882,0.9785569698425224,0.7634663607979358
|
26 |
+
0.6759014649878203,0.8859745428345971,0.9264418293477888,0.9514529797234706,0.9791745291110577,0.7687004551272422
|
27 |
+
0.6738600885168284,0.8873640511888016,0.9274367859470958,0.9514015164510927,0.9790887569904279,0.7675660379918043
|
28 |
+
0.6721446461042303,0.8834185336398257,0.9238858201530175,0.9491885957388411,0.9771503070641918,0.7653706174723683
|
29 |
+
0.673396919065427,0.8885991697258723,0.9287233677565444,0.9533399663773288,0.9803581843757505,0.7681142393185553
|
30 |
+
0.677205201221395,0.8848766596905342,0.9255669537173637,0.9494802209489828,0.9775963220914674,0.7691179334267018
|
31 |
+
0.6725906611315058,0.8832984526709439,0.9240916732425293,0.9494116032524789,0.9772189247606958,0.765669866288031
|
32 |
+
0.679126496723505,0.8885648608776203,0.9288091398771743,0.9522249288091399,0.978900058325042,0.7717041464381332
|
33 |
+
0.6765018698322297,0.8865749476790065,0.9267849178303085,0.9513500531787148,0.9794146910488215,0.7692736782586759
|
34 |
+
0.6792122688441349,0.8870724259786599,0.9285689779394106,0.9536144371633444,0.9801351768621127,0.771438903068368
|
35 |
+
0.6806360860465914,0.8879816104573369,0.9286204412117885,0.9532370398325728,0.9795690808659553,0.772461625855239
|
36 |
+
0.677170892373143,0.8858716162898411,0.926287439530655,0.9515044429958486,0.9794661543211994,0.7698102128989501
|
37 |
+
0.685816722132638,0.8909493258311318,0.9305588911380245,0.9536830548598484,0.9806498095858922,0.776847345661558
|
38 |
+
0.6874292380004803,0.8926304593954781,0.9318969362198511,0.9543177685525097,0.9809757436442859,0.778282769703823
|
39 |
+
0.6828490067588431,0.8883418533639825,0.9285003602429066,0.9527567159570454,0.9795176175935774,0.774156750244517
|
40 |
+
0.6826946169417093,0.890726318317494,0.9309705973170481,0.954986791093423,0.980889971523656,0.7746702901532684
|
41 |
+
0.679143651147631,0.8890451847531479,0.9278313377019933,0.9516759872371084,0.9774419322743335,0.771751324626232
|
42 |
+
0.6816310426458984,0.8901602223213366,0.9304731190173946,0.9536144371633444,0.9792088379593097,0.7738497806623483
|
43 |
+
0.682986242151851,0.8900229869283288,0.929066456239064,0.95205338456788,0.9785741242666484,0.7741901341300708
|
44 |
+
0.6789549524822451,0.8901945311695887,0.9288949119978043,0.9527052526846673,0.9786084331149003,0.7717603788208348
|
45 |
+
0.687738017634748,0.8921844443682025,0.9321027893093629,0.9541290698871239,0.9797749339554671,0.778173339236671
|
46 |
+
0.6816653514941503,0.8897313617181871,0.9295124712663396,0.9536487460115964,0.9787456685079082,0.7736270773705498
|
47 |
+
0.6899852471952517,0.8917212749168011,0.930095721686623,0.9531512677119429,0.9784197344495146,0.7792509607389516
|
48 |
+
0.68392973547878,0.8919271280063128,0.9309191340446701,0.9545922393385254,0.9804096476481284,0.7758120806040231
|
49 |
+
0.687772326483,0.8929735478779978,0.9319483994922291,0.9544893127937695,0.980907125947782,0.7784736944841473
|
50 |
+
0.6842042062647957,0.8922359076405805,0.9316224654338354,0.9548838645486671,0.980804199403026,0.7761278134585238
|
51 |
+
0.6846159124438193,0.8934538717535252,0.9321199437334888,0.9548495557004152,0.9802381034068687,0.7768488500009861
|
52 |
+
0.6882354959344015,0.8937626513877929,0.9338010772978351,0.9569423954437849,0.9811815967337977,0.7791075653748736
|
53 |
+
0.6855250969224963,0.8929049301814939,0.9323086423988747,0.9555871959378324,0.9802895666792466,0.77714622677133
|
54 |
+
0.6878752530277559,0.8921501355199506,0.9322057158541188,0.955021099941675,0.9800494047414828,0.778441356120615
|
55 |
+
0.684873228805709,0.8899543692318249,0.9298212509006073,0.9535458194668405,0.9789686760215459,0.7759056803242941
|
56 |
+
0.684890383229835,0.8920815178234467,0.931982708340481,0.9547980924280371,0.9799979414691049,0.7768483470243297
|
57 |
+
0.6939479191683535,0.8962157340378083,0.9336466874807012,0.9556901224825882,0.9813531409750574,0.7834821941604841
|
58 |
+
0.6864514358252993,0.8933680996328953,0.9311936048306858,0.9543863862490136,0.9796376985624593,0.7777730592023087
|
59 |
+
0.6901567914365114,0.8946889902905959,0.9330634370604178,0.9547980924280371,0.9795347720177033,0.7806016928782297
|
60 |
+
0.6914776820942121,0.8939513500531787,0.9313479946478197,0.9534600473462106,0.9786598963872782,0.7812826686600813
|
61 |
+
0.6920266236662436,0.894637527018218,0.9334408343911895,0.9557758946032182,0.9812502144303016,0.781829703594363
|
62 |
+
0.6930215802655505,0.8951521597419975,0.9326517308813943,0.9553298795759426,0.9809242803719079,0.7825955960750406
|
63 |
+
0.6919923148179915,0.8950149243489895,0.9326517308813943,0.9554499605448246,0.9804096476481284,0.781802873512625
|
64 |
+
0.6883727313274094,0.8899543692318249,0.9277970288537414,0.9525508628675335,0.9770816893676879,0.7779526954120721
|
65 |
+
0.6861769650392836,0.8913781864342813,0.9306275088345284,0.9539575256458641,0.9791059114145538,0.7770877463274903
|
66 |
+
0.6918550794249837,0.8962328884619344,0.9347617250488901,0.9580059697395958,0.9829313479946479,0.7821665815849399
|
67 |
+
0.6930901979620544,0.8964044327031941,0.9333550622705595,0.9554671149689505,0.9809414347960339,0.7827637999409208
|
68 |
+
0.6845816035955673,0.8918928191580608,0.9293237726009538,0.9518303770542423,0.9778364840292312,0.775727886193286
|
69 |
+
0.6941023089854873,0.8980684118434145,0.9342470923251106,0.9558788211479741,0.981833464850585,0.7839931362022297
|
70 |
+
0.689573541016228,0.8941572031426905,0.9325831131848904,0.9552269530311868,0.9811815967337977,0.7801137155274634
|
71 |
+
0.6899680927711257,0.891807047037431,0.9305760455621505,0.9524136274745256,0.9781452636634989,0.7794242917082135
|
72 |
+
0.6926441829347789,0.895598174769273,0.9327889662744022,0.9556215047860843,0.9811472878855457,0.7823044597371529
|
73 |
+
0.6946855594057707,0.8966789034892099,0.9346244896558823,0.9555871959378324,0.9814217586715615,0.7841933350778004
|
74 |
+
0.690654269736165,0.8944659827769582,0.9323257968230007,0.9552441074553127,0.9805297286170104,0.7806231726203139
|
75 |
+
0.6921810134833773,0.895649638041651,0.9332006724534258,0.955913129996226,0.9812159055820496,0.7820928716391452
|
76 |
+
0.6918207705767317,0.895529557072769,0.9342985555974886,0.9558445122997221,0.9821250900607267,0.7823101370916966
|
77 |
+
0.6965382372113769,0.899440765773493,0.9366658661268741,0.9581775139808557,0.9829313479946479,0.7861867148449412
|
78 |
+
0.6914433732459602,0.8950320787731156,0.9328232751226542,0.9559817476927299,0.9814389130956874,0.7817211414999055
|
79 |
+
0.6889731361718188,0.8940542765979346,0.931999862764607,0.9545922393385254,0.9805983463135143,0.7797127544519308
|
80 |
+
0.6913061378529523,0.8949634610766116,0.9325144954883865,0.9561189830857378,0.9811815967337977,0.781525423126717
|
81 |
+
0.6936048306858339,0.896473050399698,0.9326002676090164,0.9560846742374859,0.9814903763680654,0.7835110543483639
|
82 |
+
0.6941023089854873,0.8942258208391944,0.9313308402236937,0.9537859814046042,0.978917212749168,0.7828756585433598
|
83 |
+
0.6987168490753766,0.8988403609290836,0.9345901808076302,0.9566336158095172,0.9807527361306481,0.787165665870235
|
84 |
+
0.69545750849144,0.8965416680962021,0.9326860397296463,0.9552269530311868,0.9796376985624593,0.7842072813024578
|
85 |
+
0.6983737605928569,0.8981370295399184,0.9356537551034412,0.9574055648951865,0.9815075307921913,0.7869324771044554
|
86 |
+
0.696383847394243,0.8968676021545957,0.9340583936597249,0.9562562184787456,0.9800150958932309,0.7851641911299555
|
87 |
+
0.6974988849624318,0.8990119051703435,0.9364771674614883,0.9584348303427453,0.9822794798778605,0.7869184646060629
|
88 |
+
0.6960750677599753,0.8977939410573987,0.9344872542628744,0.9556729680584622,0.979912169348475,0.7852065186352417
|
89 |
+
0.6984938415617388,0.8997838542560126,0.936785947095756,0.9581775139808557,0.9815933029128212,0.787530466944267
|
90 |
+
0.6984766871376128,0.8989261330497135,0.935876762617079,0.9571997118056746,0.9802037945586166,0.7871252651258276
|
91 |
+
0.6960922221841013,0.8987888976567057,0.9356022918310632,0.956959549867911,0.9812330600061756,0.7855466480210304
|
92 |
+
0.6972415686005421,0.8972278450612413,0.9343157100216146,0.9560675198133599,0.9810786701890417,0.7856690263774906
|
93 |
+
0.6998661954918174,0.8990805228668474,0.9345901808076302,0.9566679246577693,0.9809242803719079,0.7876728318822217
|
94 |
+
0.6978248190208255,0.8990290595944694,0.9359968435859608,0.9581260507084777,0.9816276117610732,0.7868541067616021
|
95 |
+
0.6961265310323532,0.8981198751157924,0.933852540570213,0.9557415857549663,0.9791402202628058,0.7853667716098274
|
96 |
+
0.7032113081963839,0.9011047449137133,0.9365972484303702,0.9581088962843517,0.9811472878855457,0.7905386234241963
|
97 |
+
0.7025422856554705,0.8997323909836347,0.9361169245548427,0.9574227193193124,0.9810958246131677,0.7898037355362394
|
98 |
+
0.6948399492229046,0.8977253233608947,0.9346587985041342,0.9565993069612653,0.9812845232785535,0.7844807566485703
|
99 |
+
0.6989227021648883,0.8963358150066902,0.9337839228737091,0.9558959755721,0.9803067211033726,0.7864312940359843
|
100 |
+
0.6998661954918174,0.8990462140185954,0.934796033897142,0.9556215047860843,0.979912169348475,0.7876511656582856
|
101 |
+
0.6971729509040382,0.8959412632517927,0.9324115689436305,0.954008988918242,0.9792088379593097,0.7851343541574478
|
102 |
+
0.6985967681064946,0.9002984869797921,0.9374892784849213,0.9583319037979895,0.9815933029128212,0.7875866297739705
|
103 |
+
0.7022506604453288,0.8989947507462175,0.9346073352317563,0.9563419905993755,0.9803410299516245,0.7894850593461494
|
104 |
+
0.699231481799156,0.8983600370535562,0.9344700998387484,0.9559817476927299,0.979946478196727,0.7872299076283827
|
105 |
+
0.7014787113596597,0.9008817374000755,0.936820255944008,0.9582632861014856,0.981713383881703,0.7897078452703377
|
106 |
+
0.6975846570830617,0.8972621539094933,0.933989775963221,0.9556558136343363,0.9795519264418293,0.7858221981230687
|
107 |
+
0.6996260335540536,0.8993721480769891,0.9361855422513466,0.9576457268329502,0.980855662675404,0.7881251554498462
|
108 |
+
0.6976361203554396,0.9005729577658078,0.936785947095756,0.9582804405256116,0.9816276117610732,0.7872117895320041
|
109 |
+
0.7015816379044155,0.8990119051703435,0.9352077400761657,0.956856623323155,0.9794146910488215,0.7890413758536403
|
110 |
+
0.7018217998421793,0.9008302741276976,0.9366315572786221,0.9581432051326036,0.9818677736988369,0.7899203650821491
|
111 |
+
0.705956016056541,0.9030775036882012,0.9381583010258345,0.9595327134868082,0.981816310426459,0.7932645670386803
|
112 |
+
0.7033828524376436,0.9035063642913508,0.9389817133838817,0.9598414931210759,0.9826568772086322,0.7920589796937257
|
113 |
+
0.7020276529316911,0.9032147390812091,0.9387415514461179,0.9598586475452019,0.9823137887261124,0.7907239281404607
|
114 |
+
0.7005523724568566,0.9007445020070676,0.9359968435859608,0.9564620715682575,0.9800665591656088,0.7889600463763118
|
115 |
+
0.7032284626205099,0.9015336055168628,0.9372491165471575,0.957834425498336,0.9813188321268055,0.7909847811260302
|
116 |
+
0.7028853741379902,0.9019624661200123,0.9383813085394723,0.9585549113116273,0.9822108621813566,0.7910605299160958
|
117 |
+
0.7070195903523518,0.9029574227193193,0.9383641541153463,0.957851579922462,0.9813359865509315,0.7936546808268249
|
118 |
+
0.7017531821456754,0.901276289154973,0.936820255944008,0.9575771091364463,0.9807870449789,0.7899696413107273
|
119 |
+
0.7035029334065256,0.9027687240539335,0.9378323669674409,0.9588979997941469,0.9815761484886952,0.7913691735211271
|
120 |
+
0.70689950938347,0.9027858784780595,0.9388101691426218,0.9591381617319107,0.981833464850585,0.7936974704832799
|
121 |
+
0.703039763955124,0.903300511201839,0.9389817133838817,0.9593783236696745,0.9823137887261124,0.7913347277016449
|
122 |
+
0.7016502556009194,0.9015336055168628,0.9378666758156928,0.9586235290081312,0.9816790750334511,0.7902131444981638
|
123 |
+
0.7057158541187772,0.9025971798126737,0.9373177342436615,0.9591553161560367,0.9822794798778605,0.7929256361254804
|
124 |
+
0.7009640786358802,0.9010875904895873,0.9357223727999451,0.9566336158095172,0.9807184272823961,0.7890500472662872
|
125 |
+
0.7051669125467458,0.9037808350773665,0.9386043160531101,0.9593268603972964,0.9816619206093251,0.7929897654311266
|
126 |
+
0.7071225168971078,0.9051017257350671,0.9408343911894878,0.9611280749305245,0.983703297080317,0.794594734873485
|
127 |
+
0.708769341613202,0.9037636806532404,0.9391361032010156,0.9600130373623358,0.9819535458194668,0.794976115563586
|
128 |
+
0.7083404810100525,0.9041582324081381,0.9384670806601022,0.9596184856074381,0.9822108621813566,0.7949226710372033
|
129 |
+
0.7017703365698014,0.9011562081860912,0.9369574913370158,0.9579545064672179,0.9816790750334511,0.7897549305809078
|
130 |
+
0.7036573232236594,0.9018766939993824,0.9378495213915669,0.9589323086423989,0.981850619274711,0.7916034369572444
|
131 |
+
0.7054585377568875,0.9039523793186263,0.9384156173877243,0.9596184856074381,0.9826568772086322,0.7932860445089596
|
132 |
+
0.706916663807596,0.903163275808831,0.938707242597866,0.9589837719147768,0.9820050090918447,0.7939775358361482
|
133 |
+
0.7069852815040999,0.9050845713109411,0.9400624421038186,0.9609908395375167,0.9829485024187739,0.7945385075253434
|
134 |
+
0.7062647956908087,0.903163275808831,0.9378152125433149,0.9587950732493911,0.9812673688544276,0.7934864404440743
|
135 |
+
0.7067108107180842,0.9052904244004529,0.9403883761622123,0.9617456341990599,0.983703297080317,0.7943919250071902
|
136 |
+
0.7090781212474697,0.9049301814938072,0.9391017943527635,0.9599958829382098,0.9820907812124747,0.7957028931092571
|
137 |
+
0.7073626788348715,0.9060452190619961,0.9407657734929838,0.9618828695920678,0.9837890692009469,0.7950442483313808
|
138 |
+
0.7094212097299893,0.9053075788245789,0.9405084571310941,0.9614025457165403,0.9832915909012935,0.7962251490469788
|
139 |
+
0.7094898274264932,0.9045699385871616,0.9400109788314406,0.960699214327375,0.98279411260164,0.7961054241729613
|
140 |
+
0.7089065770062098,0.9050159536144372,0.9404055305863382,0.9606305966308711,0.9826911860568841,0.7957211011304709
|
141 |
+
0.7072940611383676,0.9039180704703743,0.9389302501115038,0.9589494630665248,0.9811129790372937,0.7943070084981082
|
142 |
+
0.7095584451229973,0.905153189007445,0.9410745531272515,0.9616255532301781,0.9830857378117817,0.7966290988294025
|
143 |
+
0.7099358424537688,0.9054448142175867,0.9414176416097711,0.9606820599032491,0.9828970391463958,0.7963943395382952
|
144 |
+
0.7119772189247607,0.907040175661303,0.9419494287576766,0.9619686417126977,0.9832572820530415,0.7984175961150379
|
145 |
+
0.7118056746835009,0.9052561155522009,0.9400109788314406,0.9593783236696745,0.9812845232785535,0.797624930996624
|
146 |
+
0.7138127423062408,0.9064226163927677,0.9405084571310941,0.9602875081483515,0.9821593989089786,0.798959296397886
|
147 |
+
0.7152194050845713,0.9079322057158541,0.9420695097265585,0.9617284797749339,0.9833258997495454,0.8005330807512997
|
148 |
+
0.7125433149209182,0.9075548083850825,0.9420008920300545,0.9619686417126977,0.9836518338079391,0.7988329781501703
|
149 |
+
0.7090438123992178,0.9046728651319175,0.9386043160531101,0.9594469413661784,0.9806326551617662,0.7953849481751337
|
150 |
+
0.7113768140803514,0.9046042474354136,0.9385700072048582,0.9591038528836587,0.9800665591656088,0.7968066498706435
|
151 |
+
0.7120801454695166,0.9053590420969568,0.9400795965279446,0.9601331183312176,0.9816962294575771,0.7977260981425256
|
152 |
+
0.7100559234226507,0.9055648951864685,0.9402682951933303,0.9606477510549971,0.981833464850585,0.7962313583525185
|
153 |
+
0.7126977047380519,0.9069887123889251,0.9412117885202593,0.9620201049850756,0.9824681785432463,0.7985018939595462
|
154 |
+
0.7109307990530758,0.9065770062099016,0.940645692524102,0.9599958829382098,0.9814903763680654,0.7971671390474596
|
155 |
+
0.7122516897107765,0.9061138367585,0.94069715579648,0.9602017360277215,0.9814217586715615,0.7978698135709755
|
156 |
+
0.7104676296016743,0.9054962774899646,0.940645692524102,0.9609565306892648,0.9817476927299551,0.7965702130914314
|
157 |
+
0.7113253508079733,0.9059251380931143,0.9408687000377397,0.9609222218410128,0.9826054139362541,0.7971691791256932
|
158 |
+
0.7113253508079733,0.9059594469413662,0.9399938244073146,0.9603904346931074,0.9814903763680654,0.7971970327904135
|
159 |
+
0.7127663224345558,0.9066799327546574,0.9418293477887947,0.9615054722612962,0.9822108621813566,0.7987212070933885
|
160 |
+
0.7143616838782723,0.9088928534669091,0.9430816207499915,0.962654818677737,0.9831543555082856,0.8000986555854859
|
161 |
+
0.7152365595086972,0.9098878100662161,0.9432874738395032,0.962723436374241,0.9833430541736714,0.800961004414098
|
162 |
+
0.7179641129447284,0.9099049644903421,0.9438021065632827,0.9623460390434693,0.9831886643565375,0.802537516552073
|
163 |
+
0.7143102206058942,0.9087384636497753,0.9414347960338971,0.9615740899578001,0.9825882595121281,0.7999251770665431
|
164 |
+
0.7152708683569493,0.9076577349298384,0.9418121933646687,0.9613682368682883,0.9823995608467424,0.8003246145804358
|
165 |
+
0.7156997289600988,0.9076748893539643,0.9420523553024325,0.9611452293546505,0.9821765533331046,0.8005897851903018
|
166 |
+
0.7151507873880674,0.9085497649843894,0.9421724362713144,0.9616255532301781,0.982776958177514,0.800298538385517
|
167 |
+
0.7177582598552167,0.9090815521322949,0.9427385322674717,0.9619000240161938,0.9830514289635297,0.8021783842434324
|
168 |
+
0.7155453391429649,0.9086183826808935,0.9423954437849521,0.9620372594092016,0.9829485024187739,0.8008689468904545
|
169 |
+
0.7144817648471541,0.9089614711634131,0.9423268260884482,0.961642707654304,0.9829828112670258,0.800437947894952
|
170 |
+
0.7129893299481936,0.9073146464473187,0.9409544721583697,0.9608192952962569,0.9825539506638762,0.7985213141537784
|
171 |
+
0.7184444368202559,0.9101965897004838,0.9442652760146842,0.9625175832847291,0.9839777678663327,0.8031311330802511
|
172 |
+
0.7183586646996261,0.9102823618211137,0.9441280406216763,0.9624661200123512,0.9839777678663327,0.8031333976315808
|
173 |
+
0.71842728239613,0.9080351322606101,0.9417778845164168,0.9610594572340206,0.9825882595121281,0.8024804003813029
|
174 |
+
0.7160085085943665,0.9082924486224997,0.9421895906954404,0.9607849864480049,0.9826397227845062,0.8007421064452407
|
175 |
+
0.7192163859059252,0.9107283768483893,0.9447799087384636,0.9631008337050125,0.9839263045939548,0.8037904528738081
|
176 |
+
0.7172950904038151,0.9081209043812399,0.942584142450338,0.9615397811095482,0.9822966343019864,0.801610435193648
|
177 |
+
0.7180327306412324,0.9065255429375236,0.9410745531272515,0.9600130373623358,0.9811129790372937,0.8011702379380538
|
178 |
+
0.7188561429992795,0.9108313033931451,0.9441966583181802,0.9632209146738944,0.9834459807184273,0.8034972168782052
|
179 |
+
0.7175009434933269,0.9095790304319484,0.9428071499639757,0.9612653103235325,0.9821079356366007,0.8020695604970132
|
180 |
+
0.7175009434933269,0.9095104127354444,0.9429272309328576,0.9617627886231859,0.9829313479946479,0.8021815489680041
|
181 |
+
0.7151336329639414,0.9078121247469723,0.9413490239132672,0.9610423028098947,0.9826225683603801,0.800175479488273
|
182 |
+
0.7194565478436888,0.9100593543074759,0.942566988026212,0.9621230315298316,0.9825196418156242,0.8033165975051533
|
183 |
+
0.7187189076062717,0.9100593543074759,0.9426527601468419,0.9621744948022095,0.9823995608467424,0.8029329362048628
|
184 |
+
0.7217552406765705,0.9117747967200741,0.944436820255944,0.9629807527361306,0.9831715099324115,0.8051882216053523
|
185 |
+
0.7207431296531376,0.9112087007239167,0.9439908052286685,0.962654818677737,0.982742649329262,0.8044874117887816
|
186 |
+
0.7204171955947439,0.9098706556420901,0.9441451950458023,0.9623288846193433,0.9828970391463958,0.8041066613024709
|
187 |
+
0.7210004460150273,0.9107969945448932,0.9447799087384636,0.962689127525989,0.9828798847222698,0.8045389583450916
|
188 |
+
0.7219096304937044,0.9123923559886095,0.9450886883727313,0.963632620852918,0.9833258997495454,0.8057297371947867
|
189 |
+
0.7207431296531376,0.9095447215836965,0.9431845472947473,0.9618828695920678,0.9823309431502385,0.8041280599542056
|
190 |
+
0.7218753216454523,0.9117919511442001,0.9450372251003534,0.9630665248567606,0.9829656568428998,0.8054852110655439
|
191 |
+
0.7245342573849796,0.912855525440011,0.9447455998902117,0.96358115758054,0.9836346793838131,0.8073694350037202
|
192 |
+
0.7228188149723814,0.9123752015644835,0.9446083644972039,0.9633753044910283,0.9835660616873092,0.8061516152582747
|
193 |
+
0.7228016605482553,0.9123923559886095,0.945346004734621,0.9642501801214534,0.9842007753799705,0.8064601603251264
|
194 |
+
0.7242426321748379,0.9127182900470031,0.9457577109136446,0.9643702610903352,0.9844924005901122,0.8075001427041348
|
195 |
+
0.7249459635640032,0.9134730847085464,0.9458606374584005,0.9644217243627131,0.9844237828936082,0.8081538327780139
|
196 |
+
0.7239338525405702,0.9126839811987512,0.9450200706762274,0.9633066867945244,0.9838405324733249,0.8070410918248734
|
197 |
+
0.7236250729063025,0.9124266648368614,0.9448656808590936,0.9631866058256424,0.9836003705355612,0.806859810712907
|
198 |
+
0.7224242632174838,0.911826259992452,0.9446941366178337,0.962723436374241,0.9832744364771675,0.8058241791370292
|
199 |
+
0.722218410127972,0.9117404878718222,0.9443853569835661,0.962603355405359,0.9832401276289156,0.8056427934609479
|
200 |
+
0.7226301163069956,0.9120149586578379,0.9448485264349675,0.9632552235221463,0.9833945174460493,0.8060240064701059
|
201 |
+
0.7220297114625862,0.9115860980546883,0.9442138127423062,0.962689127525989,0.9828970391463958,0.805447713601981
|
202 |
+
0.7237108450269324,0.9122722750197276,0.9448828352832196,0.9632723779462723,0.9833430541736714,0.8067396702769626
|
203 |
+
0.7235736096339246,0.9124609736851134,0.9450715339486053,0.9634096133392802,0.9836346793838131,0.8067551081005684
|
204 |
+
0.7245857206573575,0.913781864342814,0.9461007993961643,0.9646618863004769,0.9845781727107421,0.8077739135063374
|
205 |
+
0.7238309259958143,0.9125124369574913,0.9448656808590936,0.9632895323703983,0.9833087453254195,0.8068440166640881
|
206 |
+
0.7226815795793735,0.910934229937901,0.9437849521391567,0.962723436374241,0.9830171201152778,0.8056550566673708
|
207 |
+
0.7228016605482553,0.9105911414553813,0.9439050331080385,0.9623803478917212,0.9831372010841596,0.8056326358407009
|
208 |
+
0.7226129618828696,0.9121693484749717,0.9442652760146842,0.9631179881291385,0.9832915909012935,0.8061118839646912
|
209 |
+
0.7240539335094521,0.9126153635022473,0.9447455998902117,0.9634096133392802,0.9834459807184273,0.8071669628005791
|
210 |
+
0.7246200295056094,0.9130956873777747,0.9450886883727313,0.9632209146738944,0.9833773630219234,0.8075709724851876
|
211 |
+
0.7247229560503654,0.9139190997358219,0.9458091741860226,0.9640786358801935,0.9842350842282225,0.8079037101451694
|
212 |
+
0.7247058016262394,0.9128383710158849,0.9447970631625896,0.9633924589151542,0.9834802895666792,0.8075268394478943
|
213 |
+
0.7238652348440663,0.9124609736851134,0.9447627543143376,0.9631351425532645,0.9833430541736714,0.8070234679633086
|
214 |
+
0.7238652348440663,0.9124266648368614,0.9448828352832196,0.9633924589151542,0.9835489072631832,0.8069960834787101
|
215 |
+
0.7244999485367276,0.9130785329536487,0.945397468006999,0.9638556283665557,0.9840635399869626,0.8075214826372252
|
216 |
+
0.725117507805263,0.912889834288263,0.9453288503104951,0.9635983120046661,0.9838062236250729,0.8078535218173943
|
217 |
+
0.7251518166535149,0.9133186948914125,0.9457405564895186,0.9640443270319415,0.9838919957457029,0.808052022558173
|
218 |
+
0.725100353381137,0.9135931656774282,0.9460321816996603,0.9644045699385871,0.9844924005901122,0.8081660689818039
|
219 |
+
0.7251346622293889,0.9133358493155385,0.9455347034000069,0.9641300991525714,0.9838405324733249,0.8080744792082037
|
220 |
+
0.7249631179881292,0.9133701581637904,0.945397468006999,0.9638556283665557,0.9838233780491988,0.8079103037460894
|
models/nli.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
import math
|
5 |
+
import argparse
|
6 |
+
import transformers
|
7 |
+
import logging
|
8 |
+
|
9 |
+
from torch import nn
|
10 |
+
from torch.nn import DataParallel
|
11 |
+
from models.data_utils import build_batch, LoggingHandler, get_examples
|
12 |
+
from tqdm import tqdm
|
13 |
+
from sklearn.metrics import precision_recall_fscore_support
|
14 |
+
from transformers import AutoConfig, RobertaModel, AutoModel, AutoTokenizer, AdamW
|
15 |
+
|
16 |
+
import torch.nn.functional as F
|
17 |
+
import numpy as np
|
18 |
+
|
19 |
+
|
20 |
+
class NLI(nn.Module):
|
21 |
+
"""
|
22 |
+
NLI model based on BERT (using the code from: https://github.com/yg211/bert_nli)
|
23 |
+
"""
|
24 |
+
|
25 |
+
def __init__(self, model_path=None, device='cuda', parallel=False, debug=False, label_num=3, batch_size=16):
|
26 |
+
super(NLI, self).__init__()
|
27 |
+
|
28 |
+
lm = 'roberta-large'
|
29 |
+
|
30 |
+
if model_path is not None:
|
31 |
+
configuration = AutoConfig.from_pretrained(lm)
|
32 |
+
self.bert = RobertaModel(configuration)
|
33 |
+
else:
|
34 |
+
self.bert = AutoModel.from_pretrained(lm)
|
35 |
+
|
36 |
+
self.tokenizer = AutoTokenizer.from_pretrained(lm)
|
37 |
+
self.vdim = 1024
|
38 |
+
self.max_length = 256
|
39 |
+
|
40 |
+
self.nli_head = nn.Linear(self.vdim, label_num)
|
41 |
+
self.batch_size = batch_size
|
42 |
+
|
43 |
+
if parallel:
|
44 |
+
self.bert = DataParallel(self.bert)
|
45 |
+
|
46 |
+
# load trained model
|
47 |
+
if model_path is not None:
|
48 |
+
sdict = torch.load(model_path, map_location=lambda storage, loc: storage)
|
49 |
+
self.load_state_dict(sdict, strict=False)
|
50 |
+
|
51 |
+
self.to(device)
|
52 |
+
self.device = device
|
53 |
+
|
54 |
+
self.debug = debug
|
55 |
+
|
56 |
+
def load_model(self, sdict):
|
57 |
+
if self.gpu:
|
58 |
+
self.load_state_dict(sdict)
|
59 |
+
self.to('cuda')
|
60 |
+
else:
|
61 |
+
self.load_state_dict(sdict)
|
62 |
+
|
63 |
+
def forward(self, sent_pair_list):
|
64 |
+
all_probs = None
|
65 |
+
iterator = range(0, len(sent_pair_list), self.batch_size)
|
66 |
+
if self.debug:
|
67 |
+
iterator = tqdm(iterator, desc='batch')
|
68 |
+
for batch_idx in iterator:
|
69 |
+
probs = self.ff(sent_pair_list[batch_idx:batch_idx + self.batch_size]).data.cpu().numpy()
|
70 |
+
if all_probs is None:
|
71 |
+
all_probs = probs
|
72 |
+
else:
|
73 |
+
all_probs = np.append(all_probs, probs, axis=0)
|
74 |
+
labels = []
|
75 |
+
for pp in all_probs:
|
76 |
+
ll = np.argmax(pp)
|
77 |
+
if ll == 0:
|
78 |
+
labels.append('entailment')
|
79 |
+
elif ll == 1:
|
80 |
+
labels.append('contradiction')
|
81 |
+
else:
|
82 |
+
labels.append('neutral')
|
83 |
+
return labels, all_probs
|
84 |
+
|
85 |
+
def ff(self, sent_pair_list):
    """Score one batch of sentence pairs.

    :param sent_pair_list: sentence pairs handed to ``build_batch`` together
        with ``self.tokenizer`` and ``self.max_length``
    :return: ``log_softmax`` (over dim 1) of the ``nli_head`` logits, or
        ``None`` when ``build_batch`` yields no ids
    """
    ids, _types, masks = build_batch(self.tokenizer, sent_pair_list, max_len=self.max_length)
    if ids is None:
        return None

    # The token-type ids returned by build_batch are not fed to the encoder.
    input_ids = torch.tensor(ids)
    attention_mask = torch.tensor(masks)
    input_ids = input_ids.to(self.device)
    attention_mask = attention_mask.to(self.device)

    # Element [1] of the encoder output is used as the pair representation
    # (presumably the pooled first-token vector -- confirm against the model class).
    pooled = self.bert(input_ids=input_ids, attention_mask=attention_mask)[1]
    return F.log_softmax(self.nli_head(pooled), dim=1)
|
103 |
+
|
104 |
+
def save(self, output_path, config_dic=None, acc=None):
    """Write a checkpoint file into ``output_path``.

    :param output_path: target directory; it is created (with parents) when
        missing, so a first save into a fresh run directory does not fail
    :param config_dic: object to serialise; when ``None`` the module's own
        ``state_dict()`` is saved
    :param acc: optional accuracy that is embedded in the file name
    """
    if acc is None:
        model_name = 'nli_large_2.state_dict'
    else:
        model_name = 'nli_large_2_acc{}.state_dict'.format(acc)

    # An empty path means "current directory"; makedirs('') would raise.
    if output_path:
        os.makedirs(output_path, exist_ok=True)

    payload = self.state_dict() if config_dic is None else config_dic
    torch.save(payload, os.path.join(output_path, model_name))
|
114 |
+
|
115 |
+
|
116 |
+
def get_scheduler(optimizer, scheduler: str, warmup_steps: int, t_total: int):
    """
    Build the learning-rate schedule selected by ``scheduler`` (case-insensitive).

    :param optimizer: optimizer the schedule is attached to
    :param scheduler: one of 'constantlr', 'warmupconstant', 'warmuplinear',
        'warmupcosine', 'warmupcosinewithhardrestarts'
    :param warmup_steps: warm-up step count for the warm-up variants
    :param t_total: total step count for the decaying variants
    :raises ValueError: for any other scheduler name
    """
    scheduler = scheduler.lower()
    # name -> (factory in transformers.optimization, extra positional arguments)
    factories = {
        'constantlr': ('get_constant_schedule', ()),
        'warmupconstant': ('get_constant_schedule_with_warmup', (warmup_steps,)),
        'warmuplinear': ('get_linear_schedule_with_warmup', (warmup_steps, t_total)),
        'warmupcosine': ('get_cosine_schedule_with_warmup', (warmup_steps, t_total)),
        'warmupcosinewithhardrestarts': ('get_cosine_with_hard_restarts_schedule_with_warmup',
                                         (warmup_steps, t_total)),
    }
    if scheduler not in factories:
        raise ValueError("Unknown scheduler {}".format(scheduler))

    factory_name, extra_args = factories[scheduler]
    factory = getattr(transformers.optimization, factory_name)
    return factory(optimizer, *extra_args)
|
134 |
+
|
135 |
+
|
136 |
+
def train(model, optimizer, scheduler, train_data, dev_data, batch_size, fp16, gpu,
          max_grad_norm, best_acc, model_save_path):
    """Run one pass over ``train_data`` and return the best dev accuracy seen.

    Batches of ``batch_size`` examples are scored with ``model.ff``; the loss
    is ``nn.CrossEntropyLoss`` against the examples' labels. Gradients are
    clipped to ``max_grad_norm`` before every optimizer/scheduler step. Every
    5000 steps the model is evaluated on ``dev_data`` and saved to
    ``model_save_path`` when the accuracy beats ``best_acc``.

    :param fp16: when true, the loss is scaled through ``amp`` (which must be
        available as a module-level name)
    :param gpu: when true, the label tensor is moved to 'cuda'
    :return: the (possibly updated) best accuracy
    """
    criterion = nn.CrossEntropyLoss()
    model.train()

    n_examples = len(train_data)
    for step_cnt, start in enumerate(tqdm(range(0, n_examples, batch_size), desc='training'), start=1):
        stop = min(start + batch_size, n_examples)
        batch = [train_data[i] for i in range(start, stop)]
        sent_pairs = [example.get_texts() for example in batch]
        labels = [example.get_label() for example in batch]

        predicted_probs = model.ff(sent_pairs)
        if predicted_probs is None:
            # nothing to learn from this batch; no optimizer/scheduler step is taken
            continue

        true_labels = torch.LongTensor(labels)
        if gpu:
            true_labels = true_labels.to('cuda')
        loss = criterion(predicted_probs, true_labels)

        if fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        if step_cnt % 5000 != 0:
            continue

        acc = evaluate(model, dev_data, mute=True)
        logging.info('==> step {} dev acc: {}'.format(step_cnt, acc))
        model.train()  # evaluate() left the model in eval mode; switch back
        if acc > best_acc:
            best_acc = acc
            logging.info('Saving model...')
            model.save(model_save_path, model.state_dict())

    return best_acc
|
181 |
+
|
182 |
+
|
183 |
+
def parse_args(argv=None):
    """Parse the training command line.

    :param argv: optional list of argument strings; ``None`` (the default)
        makes argparse read ``sys.argv[1:]`` exactly as before, while an
        explicit list allows programmatic use
    :return: tuple ``(batch_size, epoch_num, fp16, gpu, scheduler_setting,
        max_grad_norm, warmup_percent)``; ``fp16`` and ``gpu`` are ints (0/1)
    """
    # The text is the parser's description; passing it positionally would
    # set ``prog`` and garble the usage line.
    ap = argparse.ArgumentParser(description="arguments for bert-nli training")
    ap.add_argument('-b', '--batch_size', type=int, default=64, help='batch size')
    ap.add_argument('-ep', '--epoch_num', type=int, default=10, help='epoch num')
    ap.add_argument('--fp16', type=int, default=0, help='use apex mixed precision training (1) or not (0)')
    ap.add_argument('--gpu', type=int, default=1, help='use gpu (1) or not (0)')
    ap.add_argument('-ss', '--scheduler_setting', type=str, default='WarmupLinear',
                    choices=['WarmupLinear', 'ConstantLR', 'WarmupConstant', 'WarmupCosine',
                             'WarmupCosineWithHardRestarts'])
    ap.add_argument('-mg', '--max_grad_norm', type=float, default=1., help='maximum gradient norm')
    ap.add_argument('-wp', '--warmup_percent', type=float, default=0.1,
                    help='how many percentage of steps are used for warmup')

    args = ap.parse_args(argv)
    return (args.batch_size, args.epoch_num, args.fp16, args.gpu,
            args.scheduler_setting, args.max_grad_norm, args.warmup_percent)
|
198 |
+
|
199 |
+
|
200 |
+
def evaluate(model, test_data, mute=False):
    """Compute the accuracy of ``model`` on ``test_data``.

    The model is switched to eval mode (callers that keep training must call
    ``model.train()`` again) and invoked once on all sentence pairs.

    :param model: callable returning ``(labels, probs)`` for a list of pairs
    :param test_data: sequence of examples exposing ``get_texts()`` and
        ``get_label()``
    :param mute: when false, accuracy and per-class precision/recall/F1 are
        printed; the per-class report is only computed in that case
    :return: fraction of examples whose argmax prediction equals the label
    """
    model.eval()
    sent_pairs = [example.get_texts() for example in test_data]
    all_labels = [example.get_label() for example in test_data]
    _, probs = model(sent_pairs)
    all_predict = [np.argmax(pp) for pp in probs]
    assert len(all_predict) == len(all_labels)

    hits = sum(1 for predicted, gold in zip(all_predict, all_labels) if predicted == gold)
    acc = hits * 1. / len(all_labels)

    if not mute:
        # Three-way task: report every class index, not only 0 and 1.
        prf = precision_recall_fscore_support(all_labels, all_predict, average=None, labels=[0, 1, 2])
        print('==>acc<==', acc)
        print('==>precision-recall-f1<==\n', prf)

    return acc
|
216 |
+
|
217 |
+
|
218 |
+
if __name__ == '__main__':
    # Training entry point: parse the CLI, load the NLI data, build the model,
    # optimizer and schedule, then train for ``epoch_num`` epochs.

    batch_size, epoch_num, fp16, gpu, scheduler_setting, max_grad_norm, warmup_percent = parse_args()
    # The CLI delivers 0/1 integers; the rest of the script works with booleans.
    fp16 = bool(fp16)
    gpu = bool(gpu)

    print('=====Arguments=====')
    print('batch size:\t{}'.format(batch_size))
    print('epoch num:\t{}'.format(epoch_num))
    print('fp16:\t{}'.format(fp16))
    print('gpu:\t{}'.format(gpu))
    print('scheduler setting:\t{}'.format(scheduler_setting))
    print('max grad norm:\t{}'.format(max_grad_norm))
    print('warmup percent:\t{}'.format(warmup_percent))
    print('=====Arguments=====')

    # NOTE(review): label_num is assigned but not passed to NLI() below.
    label_num = 3
    model_save_path = 'weights/entailment'

    logging.basicConfig(format='%(asctime)s - %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.INFO,
                        handlers=[LoggingHandler()])

    # Read the dataset
    train_data = get_examples('../data/allnli/train.jsonl')
    dev_data = get_examples('../data/allnli/dev.jsonl')

    logging.info('train data size {}'.format(len(train_data)))
    logging.info('dev data size {}'.format(len(dev_data)))
    # One optimizer step per batch, summed over all epochs.
    total_steps = math.ceil(epoch_num * len(train_data) * 1. / batch_size)
    warmup_steps = int(total_steps * warmup_percent)

    # NOTE(review): the parsed ``gpu`` flag is not handed to the model here;
    # it only reaches train(), which moves the labels to 'cuda' when it is set.
    model = NLI(batch_size=batch_size, parallel=True)

    optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-6, correct_bias=False)
    scheduler = get_scheduler(optimizer, scheduler_setting, warmup_steps=warmup_steps, t_total=total_steps)
    if fp16:
        try:
            # Imported at module scope on purpose: train() refers to ``amp``
            # as a global name.
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # Any real accuracy beats -1, so the first evaluation always saves.
    best_acc = -1.
    for ep in range(epoch_num):
        # Reshuffle in place so every epoch sees a different batch composition.
        random.shuffle(train_data)
        logging.info('\n=====epoch {}/{}====='.format(ep, epoch_num))
        best_acc = train(model, optimizer, scheduler, train_data, dev_data, batch_size, fp16, gpu,
                         max_grad_norm, best_acc, model_save_path)
|
models/qa_ranker.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import logging
|
4 |
+
import math
|
5 |
+
import argparse
|
6 |
+
import copy
|
7 |
+
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import numpy as np
|
11 |
+
import transformers
|
12 |
+
|
13 |
+
from torch import nn
|
14 |
+
from torch.nn import DataParallel
|
15 |
+
from transformers import BertModel, BertTokenizer
|
16 |
+
|
17 |
+
from .data_utils import build_batch, LoggingHandler, get_examples, get_qa_examples
|
18 |
+
|
19 |
+
from datetime import datetime
|
20 |
+
from tqdm import tqdm
|
21 |
+
from transformers import *
|
22 |
+
from nltk.tokenize import word_tokenize
|
23 |
+
from sklearn.metrics import precision_recall_fscore_support
|
24 |
+
|
25 |
+
#os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
26 |
+
#os.environ["CUDA_VISIBLE_DEVICES"] = "8,15"
|
27 |
+
|
28 |
+
|
29 |
+
class PassageRanker(nn.Module):
    """Pair classifier on top of a 'bert-large-cased' encoder.

    ``forward`` labels every input pair as 'relevant' (class 0) or
    'irrelevant' (any other class); ``ff`` returns the raw log-probabilities
    for a single batch and is what the training loop calls.
    """

    def __init__(self, model_path=None, gpu=True, label_num=2, batch_size=16):
        """
        :param model_path: optional checkpoint; when given, the encoder is
            built from its configuration only and the weights are taken from
            the checkpoint (loaded with ``strict=False``)
        :param gpu: move the module (and, in ``ff``, the inputs) to 'cuda'
        :param label_num: output size of the classification head
        :param batch_size: chunk size used by ``forward``
        """
        super(PassageRanker, self).__init__()

        lm = 'bert-large-cased'

        if model_path is not None:
            configuration = AutoConfig.from_pretrained(lm)
            self.language_model = BertModel(configuration)
        else:
            self.language_model = AutoModel.from_pretrained(lm)

        self.language_model = DataParallel(self.language_model)

        self.tokenizer = AutoTokenizer.from_pretrained(lm)
        self.vdim = 1024  # input width of the classification head
        self.max_length = 256

        self.classification_head = nn.Linear(self.vdim, label_num)
        self.gpu = gpu
        self.batch_size = batch_size

        # load trained model
        if model_path is not None:
            if gpu:
                sdict = torch.load(model_path)
            else:
                # keep every tensor on the CPU regardless of where it was saved
                sdict = torch.load(model_path, map_location=lambda storage, loc: storage)
            self.load_state_dict(sdict, strict=False)
        if self.gpu:
            self.to('cuda')

    def load_model(self, sdict):
        """Load ``sdict`` (strictly) and move the module to CUDA when ``self.gpu`` is set."""
        self.load_state_dict(sdict)
        if self.gpu:
            self.to('cuda')

    def forward(self, sent_pair_list):
        """Classify pairs in chunks of ``self.batch_size``.

        :param sent_pair_list: sequence of pairs (sliceable, has ``len``)
        :return: ``(labels, all_probs)`` -- one 'relevant'/'irrelevant' string
            per pair and the stacked numpy array of the ``ff`` outputs. Empty
            input yields ``[]`` and a ``(0, num_labels)`` array instead of
            raising.
        """
        # Gather the chunks and concatenate once; np.append inside the loop
        # re-copies everything accumulated so far on every batch.
        chunks = []
        with torch.no_grad():  # results leave as numpy, so no autograd graph is needed
            for start in range(0, len(sent_pair_list), self.batch_size):
                batch_scores = self.ff(sent_pair_list[start:start + self.batch_size])
                chunks.append(batch_scores.data.cpu().numpy())

        if chunks:
            all_probs = np.concatenate(chunks, axis=0)
        else:
            all_probs = np.empty((0, self.classification_head.out_features), dtype=np.float32)

        labels = ['relevant' if idx == 0 else 'irrelevant' for idx in all_probs.argmax(axis=1)]
        return labels, all_probs

    def ff(self, sent_pair_list):
        """Score one batch.

        :return: ``log_softmax`` (dim 1) of the classification-head logits, or
            ``None`` when ``build_batch`` yields no ids
        """
        ids, types, masks = build_batch(self.tokenizer, sent_pair_list, max_len=self.max_length)
        if ids is None:
            return None
        ids_tensor = torch.tensor(ids)
        types_tensor = torch.tensor(types)
        masks_tensor = torch.tensor(masks)

        if self.gpu:
            ids_tensor = ids_tensor.to('cuda')
            types_tensor = types_tensor.to('cuda')
            masks_tensor = masks_tensor.to('cuda')

        # Element [1] of the encoder output feeds the head (presumably the
        # pooled first-token vector -- confirm against the model class).
        cls_vecs = self.language_model(input_ids=ids_tensor, token_type_ids=types_tensor,
                                       attention_mask=masks_tensor)[1]
        logits = self.classification_head(cls_vecs)
        predict_probs = F.log_softmax(logits, dim=1)
        return predict_probs

    def save(self, output_path, config_dic=None, acc=None):
        """Write a checkpoint into ``output_path``.

        :param output_path: target directory; created (with parents) if missing
        :param config_dic: object to serialise; ``None`` saves ``state_dict()``
        :param acc: optional accuracy embedded in the file name
        """
        if acc is None:
            model_name = 'qa_ranker.state_dict'
        else:
            model_name = 'qa_ranker_acc{}.state_dict'.format(acc)

        # An empty path means "current directory"; makedirs('') would raise.
        if output_path:
            os.makedirs(output_path, exist_ok=True)

        payload = self.state_dict() if config_dic is None else config_dic
        torch.save(payload, os.path.join(output_path, model_name))

    @staticmethod
    def load(input_path, gpu=True, label_num=2, batch_size=16):
        """Build a ranker and strictly load the state dict stored at ``input_path``.

        With ``gpu`` false, the checkpoint tensors are mapped to the CPU.
        """
        if gpu:
            sdict = torch.load(input_path)
        else:
            sdict = torch.load(input_path, map_location=lambda storage, loc: storage)
        model = PassageRanker(gpu=gpu, label_num=label_num, batch_size=batch_size)
        model.load_state_dict(sdict)
        return model
|
131 |
+
|
132 |
+
|
133 |
+
def get_scheduler(optimizer, scheduler: str, warmup_steps: int, t_total: int):
    """
    Build the learning-rate schedule selected by ``scheduler`` (case-insensitive).

    :param optimizer: optimizer the schedule is attached to
    :param scheduler: one of 'constantlr', 'warmupconstant', 'warmuplinear',
        'warmupcosine', 'warmupcosinewithhardrestarts'
    :param warmup_steps: warm-up step count for the warm-up variants
    :param t_total: total step count for the decaying variants
    :raises ValueError: for any other scheduler name
    """
    scheduler = scheduler.lower()
    # name -> (factory in transformers.optimization, extra positional arguments)
    factories = {
        'constantlr': ('get_constant_schedule', ()),
        'warmupconstant': ('get_constant_schedule_with_warmup', (warmup_steps,)),
        'warmuplinear': ('get_linear_schedule_with_warmup', (warmup_steps, t_total)),
        'warmupcosine': ('get_cosine_schedule_with_warmup', (warmup_steps, t_total)),
        'warmupcosinewithhardrestarts': ('get_cosine_with_hard_restarts_schedule_with_warmup',
                                         (warmup_steps, t_total)),
    }
    if scheduler not in factories:
        raise ValueError("Unknown scheduler {}".format(scheduler))

    factory_name, extra_args = factories[scheduler]
    factory = getattr(transformers.optimization, factory_name)
    return factory(optimizer, *extra_args)
|
151 |
+
|
152 |
+
|
153 |
+
def train(model, optimizer, scheduler, train_data, dev_data, batch_size, fp16, gpu,
          max_grad_norm, best_acc, model_save_path):
    """Run one pass over ``train_data`` and return the best dev accuracy seen.

    Batches of ``batch_size`` examples are scored with ``model.ff``; the loss
    is ``nn.CrossEntropyLoss`` against the examples' labels. Gradients are
    clipped to ``max_grad_norm`` before every optimizer/scheduler step. Every
    5000 steps the model is evaluated on ``dev_data`` and saved to
    ``model_save_path`` when the accuracy beats ``best_acc``.

    :param fp16: when true, the loss is scaled through ``amp`` (which must be
        available as a module-level name)
    :param gpu: when true, the label tensor is moved to 'cuda'
    :return: the (possibly updated) best accuracy
    """
    criterion = nn.CrossEntropyLoss()
    model.train()

    n_examples = len(train_data)
    for step_cnt, start in enumerate(tqdm(range(0, n_examples, batch_size), desc='training'), start=1):
        stop = min(start + batch_size, n_examples)
        batch = [train_data[i] for i in range(start, stop)]
        sent_pairs = [example.get_texts() for example in batch]
        labels = [example.get_label() for example in batch]

        predicted_probs = model.ff(sent_pairs)
        if predicted_probs is None:
            # nothing to learn from this batch; no optimizer/scheduler step is taken
            continue

        true_labels = torch.LongTensor(labels)
        if gpu:
            true_labels = true_labels.to('cuda')
        loss = criterion(predicted_probs, true_labels)

        if fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        if step_cnt % 5000 != 0:
            continue

        acc = evaluate(model, dev_data, mute=True)
        logging.info('==> step {} dev acc: {}'.format(step_cnt, acc))
        model.train()  # evaluate() left the model in eval mode; switch back
        if acc > best_acc:
            best_acc = acc
            logging.info('Saving model...')
            model.save(model_save_path, model.state_dict())

    return best_acc
|
198 |
+
|
199 |
+
|
200 |
+
def parse_args(argv=None):
    """Parse the passage-ranker training command line.

    :param argv: optional list of argument strings; ``None`` (the default)
        makes argparse read ``sys.argv[1:]`` exactly as before, while an
        explicit list allows programmatic use
    :return: tuple ``(batch_size, epoch_num, fp16, gpu, scheduler_setting,
        max_grad_norm, warmup_percent)``; ``fp16`` and ``gpu`` are ints (0/1)
    """
    # The text is the parser's description (it used to be passed positionally,
    # i.e. as ``prog``, and still named the NLI script).
    ap = argparse.ArgumentParser(description="arguments for passage ranker training")
    ap.add_argument('-b', '--batch_size', type=int, default=128, help='batch size')
    ap.add_argument('-ep', '--epoch_num', type=int, default=10, help='epoch num')
    ap.add_argument('--fp16', type=int, default=0, help='use apex mixed precision training (1) or not (0)')
    ap.add_argument('--gpu', type=int, default=1, help='use gpu (1) or not (0)')
    ap.add_argument('-ss', '--scheduler_setting', type=str, default='WarmupLinear',
                    choices=['WarmupLinear', 'ConstantLR', 'WarmupConstant', 'WarmupCosine',
                             'WarmupCosineWithHardRestarts'])
    ap.add_argument('-mg', '--max_grad_norm', type=float, default=1., help='maximum gradient norm')
    ap.add_argument('-wp', '--warmup_percent', type=float, default=0.1,
                    help='how many percentage of steps are used for warmup')

    args = ap.parse_args(argv)
    return (args.batch_size, args.epoch_num, args.fp16, args.gpu,
            args.scheduler_setting, args.max_grad_norm, args.warmup_percent)
|
215 |
+
|
216 |
+
|
217 |
+
def evaluate(model, test_data, mute=False):
    """Compute the accuracy of ``model`` on ``test_data``.

    The model is switched to eval mode (callers that keep training must call
    ``model.train()`` again) and invoked once on all pairs.

    :param model: callable returning ``(labels, probs)`` for a list of pairs
    :param test_data: sequence of examples exposing ``get_texts()`` and
        ``get_label()``
    :param mute: when false, accuracy and per-class precision/recall/F1 are
        printed; the per-class report is only computed in that case
    :return: fraction of examples whose argmax prediction equals the label
    """
    model.eval()
    sent_pairs = [example.get_texts() for example in test_data]
    all_labels = [example.get_label() for example in test_data]
    _, probs = model(sent_pairs)
    all_predict = [np.argmax(pp) for pp in probs]
    assert len(all_predict) == len(all_labels)

    hits = sum(1 for predicted, gold in zip(all_predict, all_labels) if predicted == gold)
    acc = hits * 1. / len(all_labels)

    if not mute:
        # The report is only printed, so skip the sklearn call when muted.
        prf = precision_recall_fscore_support(all_labels, all_predict, average=None, labels=[0, 1])
        print('==>acc<==', acc)
        print('label meanings: 0: relevant, 1: irrelevant')
        print('==>precision-recall-f1<==\n', prf)

    return acc
|
234 |
+
|
235 |
+
|
236 |
+
if __name__ == '__main__':
    # Training entry point: parse the CLI, load the QA ranking data, build the
    # ranker, optimizer and schedule, then train for ``epoch_num`` epochs.

    batch_size, epoch_num, fp16, gpu, scheduler_setting, max_grad_norm, warmup_percent = parse_args()
    # The CLI delivers 0/1 integers; the rest of the script works with booleans.
    fp16 = bool(fp16)
    gpu = bool(gpu)

    print('=====Arguments=====')
    print('batch size:\t{}'.format(batch_size))
    print('epoch num:\t{}'.format(epoch_num))
    print('fp16:\t{}'.format(fp16))
    print('gpu:\t{}'.format(gpu))
    print('scheduler setting:\t{}'.format(scheduler_setting))
    print('max grad norm:\t{}'.format(max_grad_norm))
    print('warmup percent:\t{}'.format(warmup_percent))
    print('=====Arguments=====')

    # NOTE(review): label_num is assigned but not passed to PassageRanker()
    # below, which falls back to its own default.
    label_num = 2
    model_save_path = 'weights/passage_ranker_2'

    logging.basicConfig(format='%(asctime)s - %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.INFO,
                        handlers=[LoggingHandler()])

    # Read the dataset
    train_data = get_qa_examples('../data/qa_ranking_large/train.jsonl', dev=False)
    # Only the first 20000 dev examples are used for the periodic evaluation.
    dev_data = get_qa_examples('../data/qa_ranking_large/dev.jsonl', dev=True)[:20000]

    logging.info('train data size {}'.format(len(train_data)))
    logging.info('dev data size {}'.format(len(dev_data)))
    # One optimizer step per batch, summed over all epochs.
    total_steps = math.ceil(epoch_num * len(train_data) * 1. / batch_size)
    warmup_steps = int(total_steps * warmup_percent)

    model = PassageRanker(gpu=gpu, batch_size=batch_size)

    optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-6, correct_bias=False)
    scheduler = get_scheduler(optimizer, scheduler_setting, warmup_steps=warmup_steps, t_total=total_steps)
    if fp16:
        try:
            # Imported at module scope on purpose: train() refers to ``amp``
            # as a global name.
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # Any real accuracy beats -1, so the first evaluation always saves.
    best_acc = -1.
    # NOTE(review): the training data is not reshuffled between epochs here.
    for ep in range(epoch_num):
        logging.info('\n=====epoch {}/{}====='.format(ep, epoch_num))
        best_acc = train(model, optimizer, scheduler, train_data, dev_data, batch_size, fp16, gpu,
                         max_grad_norm, best_acc, model_save_path)
|
models/sparse_retriever.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
import logging
|
3 |
+
import math
|
4 |
+
import re
|
5 |
+
|
6 |
+
from multiprocessing import Pool
|
7 |
+
from nltk import WordNetLemmatizer, pos_tag
|
8 |
+
from nltk.corpus import wordnet, stopwords
|
9 |
+
from typing import List
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
# Module-wide lemmatizer instance shared by every tokenize() call.
lemmatizer = WordNetLemmatizer()
# English stop-word set used by tokenize(). NOTE: this rebinds the name
# ``stopwords``, so the imported nltk.corpus ``stopwords`` object is no longer
# reachable under that name after this line.
stopwords = set(stopwords.words('english'))
|
14 |
+
|
15 |
+
|
16 |
+
def get_ngrams(text_tokens: List[str], min_length=1, max_length=4) -> List[str]:
    """
    Builds word-level ngrams from a token list
    :param text_tokens: the tokens used to generate ngrams
    :param min_length: the minimum length of the generated ngrams in words
    :param max_length: the maximum length of the generated ngrams in words
        (capped at the number of tokens)
    :return: list of space-joined ngrams, shortest first; ngrams containing
        the character '#' are left out
    """
    token_count = len(text_tokens)
    longest = min(max_length, token_count)

    collected = []
    # Sizes below one produce nothing, so start at one at the earliest.
    for size in range(max(min_length, 1), longest + 1):
        for begin in range(token_count - size + 1):
            candidate = " ".join(text_tokens[begin:begin + size])
            if '#' in candidate:
                continue
            collected.append(candidate)

    return collected
|
33 |
+
|
34 |
+
|
35 |
+
def clean_text(text):
    """
    Keeps only ASCII letters, digits and single spaces
    :param text: input string
    :return: the text with hyphens turned into spaces, every other character
        outside [a-zA-Z0-9 ] dropped and runs of spaces collapsed (leading or
        trailing spaces are not stripped)
    """
    dehyphenated = text.replace('-', ' ')
    alphanumeric = re.sub('[^a-zA-Z0-9 ]+', '', dehyphenated)
    return re.sub(' +', ' ', alphanumeric)
|
45 |
+
|
46 |
+
|
47 |
+
def string_hash(string):
    """
    Returns a hash value for a string that is stable across runs
    :param string: text to hash (encoded as UTF-8)
    :return: the MD5 digest interpreted as a non-negative integer
    """
    digest = hashlib.md5(string.encode('utf8')).digest()
    # Big-endian, matching the integer value of the hexadecimal digest.
    return int.from_bytes(digest, 'big')
|
54 |
+
|
55 |
+
|
56 |
+
def tokenize(text, lemmatize=True, ngrams_length=2):
    """
    Turns raw text into a list of word ngrams
    :param text: input string; it is cleaned with clean_text, lower-cased and
        split on spaces
    :param lemmatize: when true, stop words (module-level ``stopwords`` set)
        are replaced by '#' and every token is lemmatized according to its
        part-of-speech tag; when false the tokens are used as they are
    :param ngrams_length: the maximum number of tokens per ngram
    :return: list of ngrams as produced by get_ngrams (ngrams containing '#',
        i.e. a stop word, are dropped there)
    """
    tokens = [t for t in clean_text(text).lower().split(' ') if t != '']
    if lemmatize:
        lemmatized_words = []
        # First letter of each tag, lower-cased, selects the branch below.
        pos_labels = [pos[1][0].lower() for pos in pos_tag(tokens)]

        for word, pos in zip(tokens, pos_labels):
            if word in stopwords:
                word = '#'
            if pos == 'j':
                pos = 'a'  # 'j' <--> 'a' reassignment
            if pos == 'r':  # For adverbs it's a bit different
                try:
                    lemma = wordnet.synset(word + '.r.1').lemmas()[0].pertainyms()[0].name()
                except Exception:
                    # No matching synset/pertainym: keep the word itself.
                    # (Narrowed from a bare except so Ctrl-C still interrupts.)
                    lemma = word
            elif pos in ('a', 's', 'v'):  # For adjectives and verbs
                lemma = lemmatizer.lemmatize(word, pos=pos)
            else:  # For nouns and everything else as it is the default kwarg
                lemma = lemmatizer.lemmatize(word)
            lemmatized_words.append(lemma)
        tokens = lemmatized_words

    return get_ngrams(tokens, max_length=ngrams_length)
|
91 |
+
|
92 |
+
|
93 |
+
class SparseRetriever:
    """In-memory BM25 retriever over hashed ngram buckets.

    Ngrams are mapped to ``string_hash(ngram) % ngram_buckets``; postings and
    idf values are stored per bucket.
    """

    def __init__(self, ngram_buckets=16777216, k1=1.5, b=0.75, epsilon=0.25,
                 max_relative_freq=0.5, workers=4):
        """
        :param ngram_buckets: number of hash buckets for ngrams
        :param k1: BM25 term-frequency saturation parameter
        :param b: BM25 length-normalisation parameter
        :param epsilon: factor applied to the average idf to replace negative idfs
        :param max_relative_freq: buckets occurring in a larger fraction of
            documents than this get no idf and are ignored at query time
        :param workers: number of processes used to tokenize documents
        """
        self.ngram_buckets = ngram_buckets
        self.k1 = k1
        self.b = b
        self.epsilon = epsilon
        self.max_relative_freq = max_relative_freq
        self.corpus_size = 0
        self.avgdl = 0

        self.inverted_index = {}  # bucket -> list of (doc_id, frequency)
        self.idf = {}             # bucket -> idf
        self.doc_len = []         # doc_id -> number of ngrams
        self.workers = workers

    def index_documents(self, documents):
        """Tokenize ``documents`` in a process pool and build the index from them."""
        with Pool(self.workers) as p:
            tokenized_documents = list(tqdm(p.imap(tokenize, documents), total=len(documents), desc='tokenized'))

        logging.info('Building inverted index...')

        self.corpus_size = len(tokenized_documents)
        nd = self._create_inverted_index(tokenized_documents)
        self._calc_idf(nd)

        logging.info('Built inverted index')

    def _create_inverted_index(self, documents):
        """Fill the postings lists and document lengths.

        :param documents: list of tokenized documents (lists of ngrams)
        :return: dict bucket -> number of (document, ngram) pairs hitting it
        """
        nd = {}
        num_doc = 0
        for doc_id, document in enumerate(tqdm(documents, desc='indexed')):
            self.doc_len.append(len(document))
            num_doc += len(document)

            frequencies = {}
            for word in document:
                frequencies[word] = frequencies.get(word, 0) + 1

            for word, freq in frequencies.items():
                hashed_word = string_hash(word) % self.ngram_buckets
                self.inverted_index.setdefault(hashed_word, []).append((doc_id, freq))
                nd[hashed_word] = nd.get(hashed_word, 0) + 1

        # An empty corpus has no average length; avoid dividing by zero.
        self.avgdl = num_doc / self.corpus_size if self.corpus_size else 0

        return nd

    def _calc_idf(self, nd):
        """
        Calculates the idf of every bucket from its document frequency.
        Buckets above ``max_relative_freq`` are skipped; negative idfs are
        replaced by ``epsilon * average_idf``.
        """
        # collect idf sum to calculate an average idf for epsilon value
        idf_sum = 0
        # idf can be negative if word is contained in more than half of documents
        negative_idfs = []
        for word, freq in nd.items():
            idf = math.log(self.corpus_size - freq + 0.5) - math.log(freq + 0.5)
            r_freq = float(freq) / self.corpus_size
            if r_freq > self.max_relative_freq:
                continue
            self.idf[word] = idf
            idf_sum += idf
            if idf < 0:
                negative_idfs.append(word)
        # Nothing indexed (or everything filtered out): there is no average.
        self.average_idf = idf_sum / len(self.idf) if self.idf else 0

        eps = self.epsilon * self.average_idf
        for word in negative_idfs:
            self.idf[word] = eps

    def _term_score(self, idf, freq, doc_len):
        """BM25 contribution of one query term to one document.

        ``idf * freq * (k1 + 1) / (freq + k1 * (1 - b + b * doc_len / avgdl))``
        """
        length_norm = 1 - self.b + self.b * doc_len / self.avgdl
        return idf * (freq * (self.k1 + 1) / (freq + self.k1 * length_norm))

    def _get_scores(self, query):
        """
        Scores every document sharing at least one ngram bucket with the query.
        See [Trotman, A., X. Jia, M. Crane, Towards an Efficient and Effective Search Engine] for more info
        :param query: raw query string (tokenized with tokenize)
        :return: list of (doc_id, score) sorted by descending score
        """
        query = tokenize(query)
        scores = {}
        for q in query:
            hashed_word = string_hash(q) % self.ngram_buckets
            idf = self.idf.get(hashed_word)
            if idf:
                for doc_id, freq in self.inverted_index[hashed_word]:
                    score = self._term_score(idf, freq, self.doc_len[doc_id])
                    scores[doc_id] = scores.get(doc_id, 0) + score
        return sorted(scores.items(), key=lambda x: x[1], reverse=True)

    def search(self, queries, topk=100):
        """Return, for every query, its ``topk`` best (doc_id, score) pairs."""
        results = [self._get_scores(q) for q in tqdm(queries, desc='searched')]
        results = [r[:topk] for r in results]
        logging.info('Done searching')
        return results
|
models/sparse_retriever_fast.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import tantivy
|
4 |
+
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
|
8 |
+
class SparseRetrieverFast:
    """Sparse retriever backed by a tantivy index stored on disk.

    Each document is indexed with a text field ``body`` (not stored) and an
    unsigned field ``doc_id`` (stored) holding its position in the input list.
    """

    def __init__(self, path='sparse_index', load=True):
        """
        :param path: index directory; created (with parents) when missing
        :param load: forwarded to tantivy as ``reuse``
        """
        os.makedirs(path, exist_ok=True)
        schema_builder = tantivy.SchemaBuilder()
        schema_builder.add_text_field("body", stored=False)
        schema_builder.add_unsigned_field("doc_id", stored=True)
        schema = schema_builder.build()
        self.index = tantivy.Index(schema, path=path, reuse=load)
        self.searcher = self.index.searcher()

    def index_documents(self, documents):
        """Add ``documents`` to the index, committing every 100000 documents and at the end."""
        logging.info('Building sparse index of {} docs...'.format(len(documents)))
        writer = self.index.writer()
        for i, doc in enumerate(documents):
            writer.add_document(tantivy.Document(
                body=[doc],
                doc_id=i
            ))
            if (i+1) % 100000 == 0:
                writer.commit()
                logging.info('Indexed {} docs'.format(i+1))
        writer.commit()
        logging.info('Built sparse index')
        # Refresh the searcher so the newly committed documents are visible.
        self.index.reload()
        self.searcher = self.index.searcher()

    def search(self, queries, topk=100):
        """Run every query against the ``body`` field.

        :return: one list per query with up to ``topk`` ``(doc_id, score)``
            pairs; a query that cannot be parsed or executed yields an empty
            list (best effort) and is logged at debug level
        """
        results = []
        for q in tqdm(queries, desc='searched'):
            docs = []
            try:
                query = self.index.parse_query(q, ["body"])
                scores = self.searcher.search(query, topk).hits
                docs = [(self.searcher.doc(doc_id)['doc_id'][0], score)
                        for score, doc_id in scores]
            except Exception:
                # Keep the best-effort behaviour, but leave a trace and let
                # KeyboardInterrupt/SystemExit through.
                logging.debug('Sparse search failed for query %r', q, exc_info=True)
            results.append(docs)

        return results
|
models/text_encoder.py
ADDED
@@ -0,0 +1,867 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
|
6 |
+
import hashlib
|
7 |
+
import pickle
|
8 |
+
import torch
|
9 |
+
torch.device(0)
|
10 |
+
import transformers
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
from pprint import pprint
|
15 |
+
from typing import Iterable, Tuple, Type, Dict, List, Union, Optional
|
16 |
+
from numpy.core.multiarray import ndarray
|
17 |
+
from torch import nn
|
18 |
+
from torch.nn import DataParallel
|
19 |
+
from torch.optim import Optimizer
|
20 |
+
from torch.utils.data import DataLoader, Dataset
|
21 |
+
|
22 |
+
from models.data_utils import get_qnli_examples, get_single_examples, get_ict_examples, get_examples, get_qar_examples, \
|
23 |
+
get_qar_artificial_examples, get_retrieval_examples
|
24 |
+
from tqdm import tqdm
|
25 |
+
|
26 |
+
from collections import OrderedDict
|
27 |
+
from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel
|
28 |
+
|
29 |
+
from models.vector_index import VectorIndex
|
30 |
+
|
31 |
+
|
32 |
+
class InputExample:
    """
    Container for a single example: a unique id, its texts and a label
    """
    def __init__(self, guid: str, texts: List[str], label: Union[int, float]):
        """
        Stores the guid and label as given and keeps a whitespace-stripped copy of the texts

        str.strip() is applied to every entry of texts.

        :param guid
            id for the example
        :param texts
            the texts for the example
        :param label
            the label for the example
        """
        self.guid = guid

        stripped_texts = []
        for raw_text in texts:
            stripped_texts.append(raw_text.strip())
        self.texts = stripped_texts

        self.label = label
|
52 |
+
|
53 |
+
|
54 |
+
class LoggingHandler(logging.Handler):
    """Logging handler that writes formatted records through tqdm.write."""

    def __init__(self, level=logging.NOTSET):
        super().__init__(level)

    def emit(self, record):
        """
        Format the record and print it with tqdm.write.

        Failures while formatting or writing are passed to handleError.
        KeyboardInterrupt and SystemExit are not subclasses of Exception, so
        they propagate to the caller untouched.
        """
        try:
            msg = self.format(record)
            tqdm.write(msg)
            self.flush()
        except Exception:
            self.handleError(record)
|
67 |
+
|
68 |
+
|
69 |
+
def import_from_string(dotted_path):
    """
    Import a dotted module path and return the attribute/class designated by the
    last name in the path.

    :param dotted_path:
        string of the form 'package.module.Name'
    :return:
        the attribute called Name of the imported module
    :raises ImportError:
        if dotted_path contains no dot, if the module cannot be imported, or if
        the module does not define the requested attribute. The original
        ValueError/AttributeError is attached as the cause.
    """
    try:
        module_path, class_name = dotted_path.rsplit('.', 1)
    except ValueError as err:
        msg = "%s doesn't look like a module path" % dotted_path
        raise ImportError(msg) from err

    module = importlib.import_module(module_path)

    try:
        return getattr(module, class_name)
    except AttributeError as err:
        msg = 'Module "%s" does not define a "%s" attribute/class' % (module_path, class_name)
        raise ImportError(msg) from err
|
87 |
+
|
88 |
+
|
89 |
+
def batch_to_device(batch, target_device: torch.device):
    """
    Move every tensor of a collated batch to the given device

    The feature dictionaries stored under batch['features'] are updated in place.

    :param batch:
        dictionary with a 'features' list (one dict of tensors per text column) and a 'labels' tensor
    :param target_device:
        device the tensors are moved to
    :return: tuple (features, labels) living on target_device
    """
    features = batch['features']
    for sentence_features in features:
        # Only values are replaced, the key set stays the same while iterating.
        for name in sentence_features:
            sentence_features[name] = sentence_features[name].to(target_device)

    return features, batch['labels'].to(target_device)
|
104 |
+
|
105 |
+
|
106 |
+
class BERT(nn.Module):
    """BERT model to generate token embeddings.

    Each token is mapped to an output vector from BERT.
    """
    def __init__(self, model_name_or_path: str, max_seq_length: int = 128, do_lower_case: Optional[bool] = None,
                 model_args: Optional[Dict] = None, tokenizer_args: Optional[Dict] = None):
        """
        :param model_name_or_path:
            name or folder passed to AutoModel / AutoTokenizer.from_pretrained
        :param max_seq_length:
            maximal number of tokens per text (without special tokens); values above 510 are clamped to 510
        :param do_lower_case:
            if not None, forwarded to the tokenizer as 'do_lower_case'
        :param model_args:
            extra keyword arguments for AutoModel.from_pretrained
        :param tokenizer_args:
            extra keyword arguments for AutoTokenizer.from_pretrained
        """
        super(BERT, self).__init__()
        self.config_keys = ['max_seq_length', 'do_lower_case']
        self.do_lower_case = do_lower_case

        if max_seq_length > 510:
            logging.warning("BERT only allows a max_seq_length of 510 (512 with special tokens). Value will be set to 510")
            max_seq_length = 510
        self.max_seq_length = max_seq_length

        # Work on private copies: 'do_lower_case' is written into tokenizer_args
        # below and must not leak into the caller's dict or into later instances.
        model_args = {} if model_args is None else dict(model_args)
        tokenizer_args = {} if tokenizer_args is None else dict(tokenizer_args)

        if self.do_lower_case is not None:
            tokenizer_args['do_lower_case'] = do_lower_case

        self.model = AutoModel.from_pretrained(model_name_or_path, **model_args)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **tokenizer_args)

    def forward(self, features):
        """Returns token_embeddings, cls_token

        The given features dict is passed to the underlying model as keyword
        arguments, extended in place with the outputs and returned.
        """
        output_states = self.model(**features)
        output_tokens = output_states[0]
        cls_tokens = output_tokens[:, 0, :]  # CLS token is first token
        features.update({'token_embeddings': output_tokens, 'cls_token_embeddings': cls_tokens, 'attention_mask': features['attention_mask']})

        # A third output, when the model returns one, is exposed as the per-layer embeddings.
        if len(output_states) > 2:
            features.update({'all_layer_embeddings': output_states[2]})

        return features

    def get_word_embedding_dimension(self) -> int:
        """Returns the hidden size of the underlying model"""
        return self.model.config.hidden_size

    def tokenize(self, text: str) -> List[int]:
        """
        Tokenizes a text and maps tokens to token-ids
        """
        return self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))

    def get_sentence_features(self, tokens: List[int], pad_seq_length: int):
        """
        Convert tokenized sentence in its embedding ids, segment ids and mask
        :param tokens:
            a tokenized sentence
        :param pad_seq_length:
            the maximal length of the sequence. Values greater than self.max_seq_length are reduced to it
        :return: the output of tokenizer.prepare_for_model, padded to the sequence length
        """
        pad_seq_length = min(pad_seq_length, self.max_seq_length) + 2  ##Add Space for CLS + SEP token

        # NOTE(review): pad_to_max_length looks deprecated in newer transformers
        # releases (padding='max_length') - confirm against the pinned version.
        return self.tokenizer.prepare_for_model(tokens, max_length=pad_seq_length, pad_to_max_length=True, return_tensors='pt')

    def get_config_dict(self):
        """Returns the values of the attributes listed in config_keys"""
        return {key: self.__dict__[key] for key in self.config_keys}

    def save(self, output_path: str):
        """Saves model, tokenizer and sentence_bert_config.json into output_path"""
        self.model.save_pretrained(output_path)
        self.tokenizer.save_pretrained(output_path)

        with open(os.path.join(output_path, 'sentence_bert_config.json'), 'w') as fOut:
            json.dump(self.get_config_dict(), fOut, indent=2)

    @staticmethod
    def load(input_path: str):
        """Creates a BERT module from a folder written by save()"""
        with open(os.path.join(input_path, 'sentence_bert_config.json')) as fIn:
            config = json.load(fIn)
        return BERT(model_name_or_path=input_path, **config)
|
175 |
+
|
176 |
+
|
177 |
+
class SentenceTransformer(nn.Sequential):
    """Chain of modules (e.g. BERT followed by Pooling) that maps texts to sentence embeddings.

    Provides loading/saving of the chain, batched encoding and a training loop.
    """
    def __init__(self, model_path: str = None, modules: Iterable[nn.Module] = None, device: str = None,
                 parallel=False):
        """
        :param model_path:
            folder containing a modules.json file; when given, the modules are loaded from
            that folder and the modules argument is ignored
        :param modules:
            modules to chain, used when model_path is None
        :param device:
            torch device string; when None, "cuda" is used if available, otherwise "cpu"
        :param parallel:
            when loading from model_path, wrap modules whose type string contains 'BERT' in DataParallel
        """
        if modules is not None and not isinstance(modules, OrderedDict):
            modules = OrderedDict([(str(idx), module) for idx, module in enumerate(modules)])

        if model_path is not None:
            logging.info("Load SentenceTransformer from folder: {}".format(model_path))

            with open(os.path.join(model_path, 'modules.json')) as fIn:
                contained_modules = json.load(fIn)

            modules = OrderedDict()
            for module_config in contained_modules:
                # 'type' is a dotted path ending in the class name, e.g. 'pkg.module.ClassName'.
                module_class = import_from_string(module_config['type'])
                module = module_class.load(os.path.join(model_path, module_config['path']))
                if 'BERT' in module_config['type'] and parallel:
                    module = DataParallel(module)
                modules[module_config['name']] = module

        super().__init__(modules)

        self.best_score = -1
        self.total_steps = 0

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            logging.info("Use pytorch device: {}".format(device))
        self.device = torch.device(device)
        self.to(device)

    def _eval_during_training(self, evaluator, output_path, save_best_model, epoch, steps):
        """Runs evaluation during the training and saves the model when the score improves"""
        if evaluator is not None:
            score = evaluator(self, output_path=output_path, epoch=epoch, steps=steps)
            print(score)
            if score > self.best_score and save_best_model:
                print('saving')
                self.save(output_path)
                self.best_score = score

    def evaluate(self, evaluator, output_path: str = None):
        """
        Evaluate the model

        :param evaluator:
            the evaluator
        :param output_path:
            the evaluator can write the results to this path
        """
        if output_path is not None:
            os.makedirs(output_path, exist_ok=True)
        return evaluator(self, output_path)

    def encode(self, sentences: List[str], batch_size: int = 8, show_progress_bar: bool = None, output_value: str = 'sentence_embedding', convert_to_numpy: bool = True) -> List[ndarray]:
        """
        Computes sentence embeddings
        :param sentences:
            the sentences to embed
        :param batch_size:
            the batch size used for the computation
        :param show_progress_bar:
            Output a progress bar when encode sentences
        :param output_value:
            Default sentence_embedding, to get sentence embeddings. Can be set to token_embeddings
            to get wordpiece token embeddings.
        :param convert_to_numpy:
            If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors.
        :return:
            Depending on convert_to_numpy, either a list of numpy vectors or a list of pytorch tensors
        """
        self.eval()
        if show_progress_bar is None:
            show_progress_bar = (logging.getLogger().getEffectiveLevel()==logging.INFO or logging.getLogger().getEffectiveLevel()==logging.DEBUG)

        all_embeddings = []
        # Process the texts ordered by character length; the input order is restored at the end.
        length_sorted_idx = np.argsort([len(sen) for sen in sentences])

        iterator = range(0, len(sentences), batch_size)
        if show_progress_bar:
            iterator = tqdm(iterator, desc="Batches")

        for batch_idx in iterator:
            batch_tokens = []

            batch_start = batch_idx
            batch_end = min(batch_start + batch_size, len(sentences))

            longest_seq = 0

            for idx in length_sorted_idx[batch_start: batch_end]:
                sentence = sentences[idx]
                tokens = self.tokenize(sentence)
                longest_seq = max(longest_seq, len(tokens))
                batch_tokens.append(tokens)

            features = {}
            for text in batch_tokens:
                sentence_features = self.get_sentence_features(text, longest_seq)

                for feature_name in sentence_features:
                    if feature_name not in features:
                        features[feature_name] = []
                    features[feature_name].append(sentence_features[feature_name])

            for feature_name in features:
                # First try to build one tensor from the collected values; if that
                # conversion fails, concatenate them as tensors instead.
                try:
                    features[feature_name] = torch.tensor(np.asarray(features[feature_name])).to(self.device)
                except Exception:
                    features[feature_name] = torch.cat(features[feature_name]).to(self.device)

            with torch.no_grad():
                out_features = self.forward(features)
                embeddings = out_features[output_value]

                if output_value == 'token_embeddings':
                    #Set token embeddings to 0 for padding tokens
                    input_mask = out_features['attention_mask']
                    input_mask_expanded = input_mask.unsqueeze(-1).expand(embeddings.size()).float()
                    embeddings = embeddings * input_mask_expanded

                if convert_to_numpy:
                    embeddings = embeddings.to('cpu').numpy()

                all_embeddings.extend(embeddings)

        reverting_order = np.argsort(length_sorted_idx)
        all_embeddings = [all_embeddings[idx] for idx in reverting_order]

        return all_embeddings

    def get_max_seq_length(self):
        """Returns max_seq_length of the first module, or None if it has no such attribute"""
        if hasattr(self._first_module(), 'max_seq_length'):
            return self._first_module().max_seq_length

        return None

    def tokenize(self, text):
        """Tokenizes the text with the first module"""
        return self._first_module().tokenize(text)

    def get_sentence_features(self, *features):
        """Forwards the arguments to get_sentence_features of the first module"""
        return self._first_module().get_sentence_features(*features)

    def get_sentence_embedding_dimension(self):
        """Returns the sentence embedding dimension reported by the last module"""
        return self._last_module().get_sentence_embedding_dimension()

    def _first_module(self):
        """Returns the first module of this sequential embedder (unwrapped if it is a DataParallel)"""
        first = self._modules[next(iter(self._modules))]
        # Same unwrapping rule as save(): only DataParallel wrappers are looked through.
        if isinstance(first, DataParallel):
            return first.module
        return first

    def _last_module(self):
        """Returns the last module of this sequential embedder"""
        return self._modules[next(reversed(self._modules))]

    def save(self, path):
        """
        Saves all elements for this seq. sentence embedder into different sub-folders

        Does nothing when path is None.
        """
        if path is None:
            return

        logging.info("Save model to {}".format(path))
        contained_modules = []

        for idx, name in enumerate(self._modules):
            module = self._modules[name]
            if isinstance(module, DataParallel):
                module = module.module
            model_path = os.path.join(path, str(idx) + "_" + type(module).__name__)
            os.makedirs(model_path, exist_ok=True)
            module.save(model_path)
            # Record 'module.ClassName': loading resolves this string with
            # import_from_string, which treats the last dotted component as the
            # class name, and checks it for the substring 'BERT'.
            module_type = '{}.{}'.format(type(module).__module__, type(module).__name__)
            contained_modules.append(
                {'idx': idx, 'name': name, 'path': os.path.basename(model_path), 'type': module_type})

        with open(os.path.join(path, 'modules.json'), 'w') as fOut:
            json.dump(contained_modules, fOut, indent=2)

        with open(os.path.join(path, 'config.json'), 'w') as fOut:
            json.dump({'__version__': '1.0'}, fOut, indent=2)

    def smart_batching_collate(self, batch):
        """
        Transforms a batch from a SmartBatchingDataset to a batch of tensors for the model
        :param batch:
            a batch from a SmartBatchingDataset
        :return:
            a batch of tensors for the model
        """
        num_texts = len(batch[0][0])

        labels = []
        paired_texts = [[] for _ in range(num_texts)]
        max_seq_len = [0] * num_texts
        for tokens, label in batch:
            labels.append(label)
            for i in range(num_texts):
                paired_texts[i].append(tokens[i])
                max_seq_len[i] = max(max_seq_len[i], len(tokens[i]))

        features = []
        for idx in range(num_texts):
            max_len = max_seq_len[idx]
            feature_lists = {}

            for text in paired_texts[idx]:
                sentence_features = self.get_sentence_features(text, max_len)

                for feature_name in sentence_features:
                    if feature_name not in feature_lists:
                        feature_lists[feature_name] = []

                    feature_lists[feature_name].append(sentence_features[feature_name])

            for feature_name in feature_lists:
                # Same two-step conversion as in encode().
                try:
                    feature_lists[feature_name] = torch.tensor(np.asarray(feature_lists[feature_name]))
                except Exception:
                    feature_lists[feature_name] = torch.cat(feature_lists[feature_name])

            features.append(feature_lists)

        return {'features': features, 'labels': torch.stack(labels)}

    def fit(self,
            train_objectives: Iterable[Tuple[DataLoader, nn.Module]],
            evaluator,
            epochs: int = 1,
            steps_per_epoch=None,
            scheduler: str = 'WarmupLinear',
            warmup_steps: int = 10000,
            optimizer_class: Type[Optimizer] = transformers.AdamW,
            optimizer_params: Dict[str, object] = {'lr': 2e-5, 'eps': 1e-6, 'correct_bias': False},
            acc_steps: int = 4,
            weight_decay: float = 0.01,
            evaluation_steps: int = 0,
            output_path: str = None,
            save_best_model: bool = True,
            max_grad_norm: float = 1,
            fp16: bool = False,
            fp16_opt_level: str = 'O1',
            local_rank: int = -1
            ):
        """
        Trains the model with one optimizer/scheduler per training objective

        :param train_objectives:
            (DataLoader, loss module) pairs; the collate_fn of every dataloader is replaced
            by smart_batching_collate
        :param evaluator:
            called through _eval_during_training; evaluation is skipped when it is None
        :param epochs:
            number of epochs
        :param steps_per_epoch:
            steps per epoch; None or 0 selects the length of the shortest dataloader
        :param scheduler:
            scheduler name, see _get_scheduler
        :param warmup_steps:
            warmup steps handed to the scheduler
        :param optimizer_class:
            optimizer class
        :param optimizer_params:
            keyword arguments for the optimizer (only read, never modified here)
        :param acc_steps:
            gradient accumulation: the loss is divided by acc_steps and the optimizer
            steps on every acc_steps-th step
        :param weight_decay:
            weight decay for all parameters whose name contains none of
            'bias', 'LayerNorm.bias', 'LayerNorm.weight'
        :param evaluation_steps:
            if > 0, evaluate whenever total_steps / acc_steps is a multiple of it
        :param output_path:
            folder for the saved model; created when given
        :param save_best_model:
            save the model when the evaluation score improves
        :param max_grad_norm:
            gradient clipping norm
        :param fp16:
            use apex amp mixed precision
        :param fp16_opt_level:
            apex amp opt_level
        :param local_rank:
            if not -1, the scheduler's total steps are divided by the distributed world size
        """
        if output_path is not None:
            os.makedirs(output_path, exist_ok=True)

        dataloaders = [dataloader for dataloader, _ in train_objectives]

        # Use smart batching
        for dataloader in dataloaders:
            dataloader.collate_fn = self.smart_batching_collate

        loss_models = [loss for _, loss in train_objectives]
        device = self.device

        for loss_model in loss_models:
            loss_model.to(device)

        if steps_per_epoch is None or steps_per_epoch == 0:
            steps_per_epoch = min([len(dataloader) for dataloader in dataloaders])

        num_train_steps = int(steps_per_epoch * epochs)

        # Prepare optimizers
        optimizers = []
        schedulers = []
        for loss_model in loss_models:
            param_optimizer = list(loss_model.named_parameters())

            no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
            optimizer_grouped_parameters = [
                {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                 'weight_decay': weight_decay},
                {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
            ]
            t_total = num_train_steps
            if local_rank != -1:
                t_total = t_total // torch.distributed.get_world_size()

            optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params)
            scheduler_obj = self._get_scheduler(optimizer, scheduler=scheduler, warmup_steps=warmup_steps,
                                                t_total=t_total)

            optimizers.append(optimizer)
            schedulers.append(scheduler_obj)

        if fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")

            for train_idx in range(len(loss_models)):
                model, optimizer = amp.initialize(loss_models[train_idx], optimizers[train_idx],
                                                  opt_level=fp16_opt_level)
                loss_models[train_idx] = model
                optimizers[train_idx] = optimizer

        data_iterators = [iter(dataloader) for dataloader in dataloaders]
        num_train_objectives = len(train_objectives)

        for epoch in range(epochs):
            #logging.info('Epoch {}'.format(epoch))
            epoch_loss = 0
            epoch_steps = 0
            for loss_model in loss_models:
                loss_model.zero_grad()
                loss_model.train()

            step_iterator = tqdm(range(steps_per_epoch), desc='loss: -')
            for step in step_iterator:
                for train_idx in range(num_train_objectives):
                    loss_model = loss_models[train_idx]
                    optimizer = optimizers[train_idx]
                    lr_scheduler = schedulers[train_idx]
                    data_iterator = data_iterators[train_idx]

                    try:
                        data = next(data_iterator)
                    except StopIteration:
                        # Dataloader exhausted: restart it and continue with its first batch.
                        data_iterator = iter(dataloaders[train_idx])
                        data_iterators[train_idx] = data_iterator
                        data = next(data_iterator)

                    features, labels = batch_to_device(data, self.device)
                    loss_value = loss_model(features, labels)
                    loss_value = loss_value / acc_steps

                    if fp16:
                        with amp.scale_loss(loss_value, optimizer) as scaled_loss:
                            scaled_loss.backward()
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
                    else:
                        loss_value.backward()
                        torch.nn.utils.clip_grad_norm_(loss_model.parameters(), max_grad_norm)

                    if (step + 1) % acc_steps == 0:
                        optimizer.step()
                        lr_scheduler.step()
                        optimizer.zero_grad()

                # NOTE(review): the nesting of the bookkeeping below (total_steps,
                # running loss, periodic evaluation) was reconstructed from a listing
                # that had lost its indentation - confirm against the upstream file.
                self.total_steps += 1

                if (step + 1) % acc_steps == 0:
                    epoch_steps += 1
                    epoch_loss += loss_value.item()
                    step_iterator.set_description('loss: {} - acc steps: {}'.format((epoch_loss / epoch_steps),
                                                                                    (self.total_steps / acc_steps)))

                if evaluation_steps > 0 and self.total_steps > 0 and (self.total_steps / acc_steps) % evaluation_steps == 0:
                    self._eval_during_training(evaluator, output_path, save_best_model, epoch, epoch_steps)
                    # The evaluator may have switched the model to eval mode; resume training.
                    for loss_model in loss_models:
                        loss_model.zero_grad()
                        loss_model.train()

            self._eval_during_training(evaluator, output_path, save_best_model, epoch, -1)

    def _get_scheduler(self, optimizer, scheduler: str, warmup_steps: int, t_total: int):
        """
        Returns the correct learning rate scheduler

        :param scheduler:
            case-insensitive name: constantlr, warmupconstant, warmuplinear, warmupcosine
            or warmupcosinewithhardrestarts
        :raises ValueError:
            for any other scheduler name
        """
        scheduler = scheduler.lower()
        if scheduler == 'constantlr':
            return transformers.get_constant_schedule(optimizer)
        elif scheduler == 'warmupconstant':
            return transformers.get_constant_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps)
        elif scheduler == 'warmuplinear':
            return transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)
        elif scheduler == 'warmupcosine':
            return transformers.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)
        elif scheduler == 'warmupcosinewithhardrestarts':
            return transformers.get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)
        else:
            raise ValueError("Unknown scheduler {}".format(scheduler))
|
555 |
+
|
556 |
+
|
557 |
+
class SentencesDataset(Dataset):
    """Dataset that tokenizes the texts of an example on access."""

    def __init__(self, examples: List['InputExample'], model: 'SentenceTransformer', show_progress_bar: bool = None):
        """
        :param examples:
            the examples; each one provides .texts and .label
        :param model:
            object whose tokenize(text) method is used to tokenize the texts
        :param show_progress_bar:
            accepted for interface compatibility; not used by this class
        """
        self.examples = examples
        self.model = model

    def __getitem__(self, item):
        """
        :return:
            tuple (list with the tokenized form of every text, label as float tensor)
        """
        example = self.examples[item]
        # Use the model handed to the constructor, not a module-level global.
        tokenized_texts = [self.model.tokenize(text) for text in example.texts]
        return tokenized_texts, torch.tensor(example.label, dtype=torch.float)

    def __len__(self):
        return len(self.examples)
|
568 |
+
|
569 |
+
|
570 |
+
class MultipleNegativesRankingLossANN(nn.Module):
    """Ranking loss with one positive and a fixed number of explicit negatives per anchor."""

    def __init__(self, model: 'SentenceTransformer', negative_samples=4):
        """
        :param model:
            module that maps a feature dict to a dict containing 'sentence_embedding'
        :param negative_samples:
            number of negative texts per example
        """
        super(MultipleNegativesRankingLossANN, self).__init__()
        self.model = model
        self.negative_samples = negative_samples

    def forward(self, sentence_features: Iterable[Dict[str, torch.Tensor]], labels: torch.Tensor):
        """
        :param sentence_features:
            one feature dict per text column: anchor, positive, then the negatives
        :param labels:
            not used by this loss
        """
        embeddings = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features]
        return self.multiple_negatives_ranking_loss(embeddings)

    def multiple_negatives_ranking_loss(self, embeddings: List[torch.Tensor]):
        """
        :param embeddings:
            embeddings[0] are the anchors, embeddings[1] the positives and
            embeddings[2:negative_samples + 2] the negatives
            (assumes every tensor is (batch, dim) - TODO confirm)
        :return:
            mean over the batch of: -score(anchor, positive) + logsumexp over the
            scores of that anchor with its positive and its negatives
        """
        anchors = embeddings[0]
        # Dot-product score of every anchor with its positive (column 0) and each negative.
        candidate_scores = [torch.sum(anchors * embeddings[i], dim=-1)
                            for i in range(1, self.negative_samples + 2)]
        # Stack (not cat) so that every example keeps its own row of candidate
        # scores; the log-sum-exp is then taken per example instead of over the
        # scores of the whole batch at once.
        scores = torch.stack(candidate_scores, dim=-1)

        positive_loss = torch.mean(scores[..., 0])
        negative_loss = torch.mean(torch.logsumexp(scores, dim=-1))

        return -positive_loss + negative_loss
|
589 |
+
|
590 |
+
|
591 |
+
class MultipleNegativesRankingLoss(nn.Module):
    """Ranking loss that scores every first-column text against every second-column text of the batch."""

    def __init__(self, model: SentenceTransformer):
        super().__init__()
        self.model = model

    def forward(self, sentence_features: Iterable[Dict[str, torch.Tensor]], labels: torch.Tensor):
        """
        Embeds both text columns with the model and returns the ranking loss.

        :param sentence_features:
            exactly two feature dicts, one per text column
        :param labels:
            not used by this loss
        """
        embeddings = []
        for feature in sentence_features:
            embeddings.append(self.model(feature)['sentence_embedding'])

        first, second = embeddings
        return self.multiple_negatives_ranking_loss(first, second)

    def multiple_negatives_ranking_loss(self, embeddings_a: torch.Tensor, embeddings_b: torch.Tensor):
        """
        :return:
            mean row-wise logsumexp of the score matrix minus the mean of its diagonal
        """
        # scores[i][j] is the dot product of embeddings_a[i] and embeddings_b[j].
        scores = torch.matmul(embeddings_a, embeddings_b.t())
        matching_pairs_mean = torch.diag(scores).mean()
        row_normaliser_mean = torch.logsumexp(scores, dim=1).mean()
        return row_normaliser_mean - matching_pairs_mean
|
607 |
+
|
608 |
+
|
609 |
+
class Pooling(nn.Module):
|
610 |
+
"""Performs pooling (max or mean) on the token embeddings.
|
611 |
+
|
612 |
+
Using pooling, it generates from a variable sized sentence a fixed sized sentence embedding. This layer also allows to use the CLS token if it is returned by the underlying word embedding model.
|
613 |
+
You can concatenate multiple poolings together.
|
614 |
+
"""
|
615 |
+
def __init__(self,
|
616 |
+
word_embedding_dimension: int,
|
617 |
+
pooling_mode_cls_token: bool = False,
|
618 |
+
pooling_mode_max_tokens: bool = False,
|
619 |
+
pooling_mode_mean_tokens: bool = True,
|
620 |
+
pooling_mode_mean_sqrt_len_tokens: bool = False,
|
621 |
+
):
|
622 |
+
super(Pooling, self).__init__()
|
623 |
+
|
624 |
+
self.config_keys = ['word_embedding_dimension', 'pooling_mode_cls_token', 'pooling_mode_mean_tokens', 'pooling_mode_max_tokens', 'pooling_mode_mean_sqrt_len_tokens']
|
625 |
+
|
626 |
+
self.word_embedding_dimension = word_embedding_dimension
|
627 |
+
self.pooling_mode_cls_token = pooling_mode_cls_token
|
628 |
+
self.pooling_mode_mean_tokens = pooling_mode_mean_tokens
|
629 |
+
self.pooling_mode_max_tokens = pooling_mode_max_tokens
|
630 |
+
self.pooling_mode_mean_sqrt_len_tokens = pooling_mode_mean_sqrt_len_tokens
|
631 |
+
|
632 |
+
pooling_mode_multiplier = sum([pooling_mode_cls_token, pooling_mode_max_tokens, pooling_mode_mean_tokens, pooling_mode_mean_sqrt_len_tokens])
|
633 |
+
self.pooling_output_dimension = (pooling_mode_multiplier * word_embedding_dimension)
|
634 |
+
|
635 |
+
def forward(self, features: Dict[str, torch.Tensor]):
|
636 |
+
token_embeddings = features['token_embeddings']
|
637 |
+
cls_token = features['cls_token_embeddings']
|
638 |
+
attention_mask = features['attention_mask']
|
639 |
+
|
640 |
+
## Pooling strategy
|
641 |
+
output_vectors = []
|
642 |
+
if self.pooling_mode_cls_token:
|
643 |
+
output_vectors.append(cls_token)
|
644 |
+
if self.pooling_mode_max_tokens:
|
645 |
+
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
646 |
+
token_embeddings[input_mask_expanded == 0] = -1e9 # Set padding tokens to large negative value
|
647 |
+
max_over_time = torch.max(token_embeddings, 1)[0]
|
648 |
+
output_vectors.append(max_over_time)
|
649 |
+
if self.pooling_mode_mean_tokens or self.pooling_mode_mean_sqrt_len_tokens:
|
650 |
+
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
651 |
+
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
|
652 |
+
|
653 |
+
#If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present
|
654 |
+
if 'token_weights_sum' in features:
|
655 |
+
sum_mask = features['token_weights_sum'].unsqueeze(-1).expand(sum_embeddings.size())
|
656 |
+
else:
|
657 |
+
sum_mask = input_mask_expanded.sum(1)
|
658 |
+
|
659 |
+
sum_mask = torch.clamp(sum_mask, min=1e-9)
|
660 |
+
|
661 |
+
if self.pooling_mode_mean_tokens:
|
662 |
+
output_vectors.append(sum_embeddings / sum_mask)
|
663 |
+
if self.pooling_mode_mean_sqrt_len_tokens:
|
664 |
+
output_vectors.append(sum_embeddings / torch.sqrt(sum_mask))
|
665 |
+
|
666 |
+
output_vector = torch.cat(output_vectors, 1)
|
667 |
+
features.update({'sentence_embedding': output_vector})
|
668 |
+
return features
|
669 |
+
|
670 |
+
def get_sentence_embedding_dimension(self):
    """Return the size of the pooled sentence vector (``pooling_output_dimension``)."""
    dimension = self.pooling_output_dimension
    return dimension
|
672 |
+
|
673 |
+
def get_config_dict(self):
    """Collect the attributes named in ``self.config_keys`` into a plain dict."""
    config = {}
    for key in self.config_keys:
        config[key] = self.__dict__[key]
    return config
|
675 |
+
|
676 |
+
def save(self, output_path):
    """Write the pooling configuration to ``<output_path>/config.json`` (2-space indented JSON)."""
    config_file = os.path.join(output_path, 'config.json')
    with open(config_file, 'w') as handle:
        json.dump(self.get_config_dict(), handle, indent=2)
|
679 |
+
|
680 |
+
@staticmethod
def load(input_path):
    """Rebuild a ``Pooling`` module from the ``config.json`` stored in ``input_path``."""
    config_file = os.path.join(input_path, 'config.json')
    with open(config_file) as handle:
        kwargs = json.load(handle)

    return Pooling(**kwargs)
|
686 |
+
|
687 |
+
|
688 |
+
class RankingEvaluator:
    """Retrieval evaluator reporting recall@{1,5,10,20,100} and MRR@100.

    Each batch of ``dataloader`` yields two feature sets; the first is embedded
    as the query side and the second as the evidence side. All evidence
    embeddings (plus optional distractors from ``random_paragraphs``) are put
    into a ``VectorIndex`` and every query is searched against it. A hit is a
    retrieved item whose evidence ``input_ids`` hash equals that of the
    query's own evidence.
    """

    def __init__(self, dataloader: DataLoader, random_paragraphs: DataLoader = None,
                 name: str = '', show_progress_bar: bool = None):
        """
        Constructs an evaluator for the dataset.

        :param dataloader:
            the (query, evidence) pairs used for the evaluation
        :param random_paragraphs:
            optional loader of extra paragraphs added to the index as distractors
        :param name:
            label stored on the evaluator
        :param show_progress_bar:
            stored on the evaluator; when None it is enabled iff the root
            logger's effective level is INFO or DEBUG
        """
        self.dataloader = dataloader
        self.name = name

        if show_progress_bar is None:
            level = logging.getLogger().getEffectiveLevel()
            show_progress_bar = level in (logging.INFO, logging.DEBUG)
        self.show_progress_bar = show_progress_bar

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.random_paragraphs = random_paragraphs

    def __call__(self, model: 'SequentialSentenceEmbedder', output_path: str = None, epoch: int = -1,
                 steps: int = -1) -> float:
        """Run the evaluation and return MRR@100.

        :param model: encoder to evaluate; switched to eval mode
        :param output_path: if given, a CSV row is appended to ``<output_path>/stats.csv``
        :param epoch: current epoch, only used in the log message (-1 = unknown)
        :param steps: current step, only used in the log message (-1 = unknown)
        :return: mean reciprocal rank over all examples
        """
        model.eval()
        embeddings1 = []
        embeddings2 = []

        if epoch != -1:
            if steps == -1:
                out_txt = f" after epoch {epoch}:"
            else:
                out_txt = f" in epoch {epoch} after {steps} steps:"
        else:
            out_txt = ":"

        logging.info("Evaluation" + out_txt)

        logging.info('Calculating sentence embeddings...')

        self.dataloader.collate_fn = model.smart_batching_collate
        evidence_features = []
        for batch in tqdm(self.dataloader, desc='batch'):
            features, label_ids = batch_to_device(batch, self.device)
            with torch.no_grad():
                emb1, emb2 = [model(sent_features)['sentence_embedding'].to("cpu").numpy()
                              for sent_features in features]
            embeddings1.extend(emb1)
            embeddings2.extend(emb2)

            # Fingerprint every evidence by its token ids so that duplicated
            # evidences count as correct hits.
            input_ids_2 = features[1]['input_ids'].to("cpu").numpy()
            for ids in input_ids_2:
                evidence_features.append(hashlib.sha1(str(ids).encode('utf-8')).hexdigest())

        total_examples = len(embeddings1)

        logging.info('Building index...')

        vindex = VectorIndex(len(embeddings1[0]))
        for v in embeddings2:
            vindex.add(v)

        if self.random_paragraphs is not None:
            logging.info('Calculating random wikipedia paragraph embeddings...')

            self.random_paragraphs.collate_fn = model.smart_batching_collate
            iterator = self.random_paragraphs
            for step, batch in enumerate(iterator):
                if step % 10 == 0:
                    logging.info('Batch {}/{}'.format(step, len(iterator)))
                features, label_ids = batch_to_device(batch, self.device)
                with torch.no_grad():
                    embeddings = model(features[0])['sentence_embedding'].to("cpu").numpy()
                for emb in embeddings:
                    vindex.add(emb)

        vindex.build()
        # build() only creates (and, for large sets, trains) the FAISS index; the
        # staged vectors have to be inserted explicitly or every search below
        # would run against an empty index.
        vindex.add_vectors(vindex.vectors)

        logging.info('Ranking evaluation...')

        # Kept from the original implementation; presumably keeps the returned
        # score strictly positive - TODO confirm with the training loop.
        mrr = 1e-8
        recall = {1: 0, 5: 0, 10: 0, 20: 0, 100: 0}

        all_results, _ = vindex.search(embeddings1, k=100, probes=1024)

        for i in range(total_examples):
            hit_rank = None
            for rank, r in enumerate(all_results[i], start=1):
                # Ids >= len(evidence_features) are distractor paragraphs; FAISS
                # pads missing results with -1, which must not wrap around to the
                # last evidence.
                if 0 <= r < len(evidence_features) and evidence_features[r] == evidence_features[i]:
                    hit_rank = rank
                    break

            if hit_rank is None:
                continue

            mrr += 1 / hit_rank
            for topk in recall:
                if hit_rank <= topk:
                    recall[topk] += 1

        mrr /= total_examples
        for topk in recall:
            recall[topk] /= total_examples
            logging.info('recall@{} : {}'.format(topk, recall[topk]))

        logging.info('mrr@100 : {}'.format(mrr))

        if output_path is not None:
            # Context manager guarantees the handle is closed even if the write fails.
            with open(os.path.join(output_path, 'stats.csv'), 'a+') as f:
                f.write('{},{},{},{},{},{}\n'.format(recall[1], recall[5], recall[10], recall[20], recall[100], mrr))

        return mrr
|
806 |
+
|
807 |
+
|
808 |
+
# Script entry point: trains the dense retrieval encoder with a
# multiple-negatives ranking loss and evaluates it with RankingEvaluator.
if __name__ == "__main__":
    # The root logger is set to CRITICAL, so the logging.info calls below are
    # suppressed unless this level is lowered.
    logging.basicConfig(format='%(asctime)s - %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.CRITICAL,
                        handlers=[LoggingHandler()])

    # Training configuration
    model_name = 'distilroberta-base'
    batch_size = 384
    model_save_path = 'weights/encoder/qrbert-multitask-distil'
    num_epochs = 1000

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Use BERT for mapping tokens to embeddings
    word_embedding_model = BERT(model_name, max_seq_length=256, do_lower_case=False)

    # Apply mean pooling to get one fixed sized sentence vector
    pooling_model = Pooling(word_embedding_model.get_word_embedding_dimension(),
                            pooling_mode_mean_tokens=True,
                            pooling_mode_cls_token=False,
                            pooling_mode_max_tokens=False)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")

    # The token-embedding module is wrapped for data-parallel execution and the
    # SentenceTransformer is built with parallel=True.
    word_embedding_model = DataParallel(word_embedding_model)
    word_embedding_model.to(device)

    model = SentenceTransformer(modules=[word_embedding_model, pooling_model], parallel=True)

    #model = SentenceTransformer('models/weights/encoder/sbert-nli-fnli-qar2-768', parallel=True)

    # Dev split: used only by the evaluator.
    logging.info("Reading dev dataset")
    dev_data = SentencesDataset(get_retrieval_examples(filename='../data/retrieval/dev.jsonl',
                                                       no_statements=False,
                                                       ), model=model)
    dev_dataloader = DataLoader(dev_data, shuffle=True, batch_size=1024)
    evaluator = RankingEvaluator(dev_dataloader)
    #model.best_score = model.evaluate(evaluator)

    # Train split and loss.
    train_data = SentencesDataset(get_retrieval_examples(filename='../data/retrieval/train.jsonl',
                                                         no_statements=False),
                                  model=model)
    train_dataloader = DataLoader(train_data, shuffle=True, batch_size=batch_size)
    train_loss = MultipleNegativesRankingLoss(model=model)

    # fit() is called once per epoch (epochs=1) so the CUDA cache can be
    # released between epochs.
    for i in range(num_epochs):
        logging.info("Epoch {}".format(i))
        # Train the model
        model.fit(train_objectives=[(train_dataloader, train_loss)],
                  evaluator=evaluator,
                  epochs=1,
                  acc_steps=1,
                  evaluation_steps=1000,
                  warmup_steps=1000,
                  output_path=model_save_path,
                  optimizer_params={'lr': 1e-6, 'eps': 1e-6, 'correct_bias': False}
                  )
        torch.cuda.empty_cache()
|
models/tokenization.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def get_segments(sentence):
    """Return segment ids for a space-separated token string.

    Every token receives the current segment number; the number is increased
    after each ``[SEP]`` token (the ``[SEP]`` itself still belongs to the
    segment it closes). The ids are returned wrapped in a one-element list.
    """
    segment_ids = []
    current_segment = 0
    for token in sentence.split(" "):
        segment_ids.append(current_segment)
        if token == "[SEP]":
            current_segment += 1
    return [segment_ids]
|
14 |
+
|
15 |
+
|
16 |
+
def tokenize(text, max_length, tokenizer, second_text=None):
    """Encode one text, or a text pair, into BERT-style model inputs.

    A single text is truncated to ``max_length - 2`` tokens and wrapped as
    ``[CLS] ... [SEP]``. For a pair, the longer token list is shortened one
    token at a time until both fit into ``max_length - 3`` tokens, and the
    result is ``[CLS] a [SEP] b [SEP]``. The sequence is then padded with
    ``[PAD]`` up to ``max_length``.

    :param text: first (or only) input text
    :param max_length: target sequence length including special tokens
    :param tokenizer: object providing ``tokenize`` and ``convert_tokens_to_ids``
    :param second_text: optional second text of a pair
    :return: ``[input_ids, segment_ids, attention_mask]`` as int32 numpy arrays
    """
    # NOTE(review): str.replace() matches its argument literally, it does not
    # interpret regular expressions, so this call only removes the exact text
    # below and is effectively a no-op. It is preserved deliberately: switching
    # to re.sub would strip punctuation and change the inputs existing models
    # were built with - confirm before changing.
    # The literal is spelled with escaped backslashes; the previous '\w' / '\s'
    # produced the same characters but are invalid escape sequences.
    literal_pattern = '[^\\w\\s]+|\n'

    if second_text is None:
        first_tokens = tokenizer.tokenize(text.replace(literal_pattern, ''))[:max_length - 2]
        sentence = "[CLS] " + " ".join(first_tokens) + " [SEP]"
    else:
        first_tokens = tokenizer.tokenize(text.replace(literal_pattern, ''))
        second_tokens = tokenizer.tokenize(second_text.replace(literal_pattern, ''))
        # Trim whichever side is currently longer (ties trim the second text).
        while len(first_tokens) + len(second_tokens) > max_length - 3:
            if len(first_tokens) > len(second_tokens):
                first_tokens.pop()
            else:
                second_tokens.pop()
        sentence = "[CLS] " + " ".join(first_tokens) + " [SEP] " + " ".join(second_tokens) + " [SEP]"

    token_count = len(sentence.split(" "))
    padding = max_length - token_count

    # Attention mask: 1 for real tokens, 0 for padding,
    # e.g. max_length 100 and 90 tokens -> [1] * 90 + [0] * 10.
    sentence_mask = [1] * token_count + [0] * padding

    # Pad the token string up to max_length (no-op when it already fits).
    sentence_padded = sentence + " [PAD]" * padding

    sentence_converted = tokenizer.convert_tokens_to_ids(sentence_padded.split(" "))

    # A new segment id starts after every [SEP].
    sentence_segment = get_segments(sentence_padded)

    return [np.asarray(sentence_converted, dtype=np.int32).squeeze(),
            np.asarray(sentence_segment, dtype=np.int32).squeeze(),
            np.asarray(sentence_mask, dtype=np.int32).squeeze()]
|
models/vector_index.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import math
|
3 |
+
import pickle
|
4 |
+
import h5py
|
5 |
+
import faiss
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import gpustat
|
9 |
+
|
10 |
+
import pdb
|
11 |
+
|
12 |
+
# see http://ulrichpaquet.com/Papers/SpeedUp.pdf theorem 5
|
13 |
+
|
14 |
+
def get_phi(xb):
    """Return the largest squared L2 norm among the rows of ``xb``."""
    squared_norms = (xb ** 2).sum(1)
    return squared_norms.max()
|
16 |
+
|
17 |
+
|
18 |
+
def augment_xb(xb, phi=None):
    """Append the column ``sqrt(phi - ||x||^2)`` to every row of ``xb``.

    ``phi`` defaults to the largest squared row norm of ``xb``, which gives
    all augmented rows the same norm.
    """
    squared_norms = (xb ** 2).sum(1)
    bound = squared_norms.max() if phi is None else phi
    extra_column = np.sqrt(bound - squared_norms).reshape(-1, 1)
    return np.hstack((xb, extra_column))
|
24 |
+
|
25 |
+
|
26 |
+
def augment_xq(xq):
    """Append a float32 zero column to every row of the query matrix ``xq``."""
    zero_column = np.zeros(len(xq), dtype='float32').reshape(-1, 1)
    return np.hstack((xq, zero_column))
|
29 |
+
|
30 |
+
|
31 |
+
class VectorIndex:
    """Small wrapper around a FAISS index.

    Typical flow: stage vectors with ``add``, call ``build`` to create (and,
    for large collections, train) the index, insert vectors with
    ``add_vectors`` and query with ``search``. Note that ``build`` itself does
    not insert the staged vectors into the index.
    """

    def __init__(self, d):
        """
        :param d: dimensionality of the indexed vectors
        """
        self.d = d
        # Staged vectors; a Python list until build() turns it into an ndarray.
        self.vectors = []
        # FAISS index, created by build() or load().
        self.index = None

    def add(self, v):
        """Stage one vector for ``build`` (only valid before ``build`` is called)."""
        self.vectors.append(v)

    def add_vectors(self, vs):
        """Insert the vectors ``vs`` into the FAISS index (requires ``build`` or ``load`` first)."""
        logging.info('Adding vectors to index...')
        self.index.add(vs)

    def build(self, use_gpu=False):
        """Create the FAISS index from the staged vectors.

        The staged vectors are converted to an array and L2-normalised in
        place. More than 50000 vectors get a trained IVF/HNSW index, smaller
        collections a flat L2 index. The vectors are not added to the index
        here; use ``add_vectors`` for that.

        :param use_gpu: use all available GPUs (for IVF clustering, or for the flat index)
        """
        self.vectors = np.array(self.vectors)  # OOM at this step if building too many vectors

        faiss.normalize_L2(self.vectors)

        #self.vectors = augment_xq(self.vectors)

        num_vectors = self.vectors.shape[0]
        logging.info('Indexing {} vectors'.format(num_vectors))

        if num_vectors > 50000:
            # 8 * sqrt(largest power of two <= number of vectors)
            num_centroids = 8 * int(math.sqrt(math.pow(2, int(math.log(num_vectors, 2)))))

            logging.info('Using {} centroids'.format(num_centroids))

            self.index = faiss.index_factory(self.d, "IVF{}_HNSW32,Flat".format(num_centroids))

            ngpu = faiss.get_num_gpus()
            if ngpu > 0 and use_gpu:
                logging.info('Using {} GPUs'.format(ngpu))

                # Run only the k-means clustering on the GPUs.
                index_ivf = faiss.extract_index_ivf(self.index)
                gpustat.print_gpustat()
                clustering_index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(self.d))
                index_ivf.clustering_index = clustering_index

            logging.info('Training index...')

            self.index.train(self.vectors)
        else:
            self.index = faiss.IndexFlatL2(self.d)
            if faiss.get_num_gpus() > 0 and use_gpu:
                gpustat.print_gpustat()
                self.index = faiss.index_cpu_to_all_gpus(self.index)

    def load(self, path):
        """Load a FAISS index previously written with ``save``."""
        self.index = faiss.read_index(path)

    def save(self, path):
        """Write the index to ``path`` after converting it to a CPU index."""
        gpustat.print_gpustat()
        faiss.write_index(faiss.index_gpu_to_cpu(self.index), path)

    def save_vectors(self, path):
        """Store the vectors as dataset ``data`` in an HDF5 file at ``path``."""
        #pickle.dump(self.vectors, open(path, 'wb'), protocol=4)
        # Context manager closes the file even if writing the dataset fails.
        with h5py.File(path, 'w') as f:
            f.create_dataset('data', data=self.vectors)

    def search(self, vectors, k=1, probes=1):
        """Return the ``k`` nearest neighbours for each query vector.

        :param vectors: query vectors (converted to an ndarray if necessary)
        :param k: number of neighbours per query
        :param probes: ``nprobe`` value, applied when the index accepts it
        :return: ``(ids, similarities)`` where each similarity is ``(2 - d) / 2``
            for the distance ``d`` reported by FAISS. NOTE(review): this is a
            cosine similarity only if the queries are unit-normalised by the
            caller; normalisation is commented out here.
        """
        if not isinstance(vectors, np.ndarray):
            vectors = np.array(vectors)
        #faiss.normalize_L2(vectors)
        try:
            self.index.nprobe = probes
        except Exception:
            # Best effort: not every index type takes nprobe. Unlike a bare
            # except this no longer swallows KeyboardInterrupt / SystemExit.
            pass
        distances, ids = self.index.search(vectors, k)
        similarities = [(2 - d) / 2 for d in distances]
        return ids, similarities
|
requirements.txt
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
tqdm==4.41.0
|
2 |
+
transformers==3.0.2
|
3 |
+
faiss_gpu==1.6.3
|
4 |
+
syllables==0.1.0
|
5 |
+
scipy==1.3.1
|
6 |
+
nltk==3.2.5
|
7 |
+
lxml==4.6.3
|
8 |
+
requests==2.22.0
|
9 |
+
Flask_Cors==3.0.8
|
10 |
+
gpustat==0.6.0
|
11 |
+
h5py==2.10.0
|
12 |
+
packaging==20.3
|
13 |
+
pandas==1.0.3
|
14 |
+
spacy==3.0.3
|
15 |
+
sentence-transformers==0.3.0
|
16 |
+
numpy==1.17.4
|
17 |
+
matplotlib==3.3.1
|
18 |
+
filelock==3.0.12
|
19 |
+
boilerpipe3==1.1
|
20 |
+
Unidecode==1.1.1
|
21 |
+
Flask==2.0.0
|
22 |
+
apex
|
23 |
+
beautifulsoup4==4.9.3
|
24 |
+
python_dateutil==2.8.1
|
25 |
+
scikit_learn==0.24.1
|
26 |
+
tantivy==0.13.2
|
27 |
+
huggingface_hub==0.16.4
|
28 |
+
torch==1.6.0+cu101
|
29 |
+
torchvision==0.7.0+cu101
|
30 |
+
gradio==3.34.0
|
utils.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from html.parser import HTMLParser
|
2 |
+
|
3 |
+
|
4 |
+
class WebParser(HTMLParser):
    """
    Splits HTML into text blocks: text found inside allowed tags is collected
    and emitted as one block whenever a block-level tag is closed.
    """
    def __init__(self):
        super().__init__()
        # Closing one of these tags finishes the current block.
        self.block_tags = {'div', 'p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6'}
        self.inline_tags = {'', 'a', 'b', 'main', 'span', 'em', 'strong', 'br'}
        # Text is only kept when the innermost open tag is one of these.
        self.allowed_tags = {'div', 'p', '', 'a', 'b', 'main', 'span', 'em', 'strong', 'br'}

        self.opened_tags = []
        self.block_content = ''
        self.blocks = []

    def get_last_opened_tag(self):
        """
        Returns the innermost tag that is currently open, or '' if none is
        :return:
        """
        return self.opened_tags[-1] if self.opened_tags else ''

    def error(self, message):
        pass

    def handle_starttag(self, tag, attrs):
        """
        Records an opening tag on the stack of open tags
        :param tag: the HTML tag
        :param attrs: the tag attributes
        :return:
        """
        self.opened_tags.append(tag)

    def handle_endtag(self, tag):
        """
        Finishes the current block if a block tag closes, then pops the tag stack
        :param tag: the HTML tag
        :return:
        """
        if tag in self.block_tags:
            finished = self.block_content.strip()
            if finished:
                #if not self.block_content.endswith('.'): self.block_content += '.'
                self.blocks.append(finished)
            self.block_content = ''
        if self.opened_tags:
            self.opened_tags.pop()

    def handle_data(self, data):
        """
        Appends a text node to the current block when its parent tag is allowed
        :param data: the text node
        :return:
        """
        if self.get_last_opened_tag() not in self.allowed_tags:
            return
        text = data.replace(' ', ' ').strip()
        if text != '':
            self.block_content += text + ' '

    def get_blocks(self):
        """Returns the list of collected text blocks"""
        return self.blocks
|
web_search.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
from pprint import pprint
|
4 |
+
from urllib.request import Request, urlopen
|
5 |
+
from html.parser import HTMLParser
|
6 |
+
from bs4 import BeautifulSoup
|
7 |
+
from urllib.parse import quote_plus
|
8 |
+
|
9 |
+
|
10 |
+
def bing_search(query: str, pages_number=1) -> list:
    """
    Gets web results from Bing
    :param query: query to search
    :param pages_number: number of search pages to scrape
    :return: a list of links in ranked order
    """
    urls = []
    for page in range(pages_number):
        # Bing paginates with a 1-based result offset, 10 results per page.
        first = page * 10 + 1
        address = "https://www.bing.com/search?q=" + quote_plus(query) + '&first=' + str(first)
        data = get_html(address)
        soup = BeautifulSoup(data, 'lxml')
        for result in soup.findAll('li', {'class': 'b_algo'}):
            # Skip malformed results instead of letting one missing heading,
            # anchor or href abort the whole search.
            heading = result.find('h2')
            anchor = heading.find('a') if heading is not None else None
            href = anchor.get('href') if anchor is not None else None
            if href is not None:
                urls.append(href)

    return urls
|
27 |
+
|
28 |
+
|
29 |
+
def duckduckgo_search(query: str, pages=1):
    """
    Gets web results from the DuckDuckGo HTML endpoint
    :param query: query to search
    :param pages: number of result pages to request
    :return: a list of links in ranked order
    """
    urls = []
    offset = 0
    for _ in range(pages):
        address = "https://duckduckgo.com/html/?kl=en-us&q={}&s={}".format(quote_plus(query), offset)
        soup = BeautifulSoup(get_html(address), 'lxml')
        anchors = soup.findAll('a', {'class': 'result__a'})
        urls.extend([anchor['href'] for anchor in anchors])
        # The next page starts after everything collected so far.
        offset = len(urls)

    return urls
|
41 |
+
|
42 |
+
|
43 |
+
def news_search(query: str, pages=1):
    """
    Gets news results from newslookup.com
    :param query: query to search
    :param pages: number of result pages to request
    :return: a list of article links
    """
    urls = []
    for page in range(pages):
        api_url = f'https://newslookup.com/results?l=2&q={quote_plus(query)}&dp=&mt=-1&mkt=0&mtx=0&mktx=0&s=&groupby=no&cat=-1&from=&fmt=&tp=720&ps=50&ovs=&page={page}'
        soup = BeautifulSoup(get_html(api_url), 'lxml')
        for anchor in soup.findAll('a', {'class': 'title'}):
            urls.append(anchor['href'])
    return urls
|
52 |
+
|
53 |
+
|
54 |
+
def get_html(url: str) -> str:
    """
    Downloads the html source code of a webpage
    :param url: address of the page
    :return: ``str()`` of the raw response bytes (i.e. the ``b'...'`` repr, with
             newlines appearing as literal backslash sequences), or '' on any failure
    """
    custom_user_agent = "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:47.0) Gecko/20100101 Firefox/47.0"
    try:
        req = Request(url, headers={"User-Agent": custom_user_agent})
        # The context manager closes the response even when read() fails.
        with urlopen(req, timeout=3) as page:
            # NOTE(review): str(bytes) is kept on purpose - the parser in this
            # module strips literal '\\n' / '\\r' sequences from this output.
            return str(page.read())
    except Exception:
        # Best effort: any download problem yields an empty page. Unlike a bare
        # except this lets KeyboardInterrupt / SystemExit propagate.
        return ''
|
67 |
+
|
68 |
+
|
69 |
+
class WebParser(HTMLParser):
    """
    Converts tagged html into plain-text blocks: text inside allowed tags is
    accumulated and flushed as one block each time a new block tag opens.
    """

    def __init__(self):
        super().__init__()
        # Opening one of these tags flushes the text collected so far.
        self.block_tags = {
            'div', 'p'
        }
        self.inline_tags = {
            '', 'a', 'b', 'tr', 'main', 'span', 'time', 'td',
            'sup', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'em', 'strong', 'br'
        }
        # Text is only kept when the innermost open tag is one of these.
        self.allowed_tags = self.block_tags.union(self.inline_tags)
        self.opened_tags = []
        self.block_content = ''
        self.blocks = []

    def get_last_opened_tag(self):
        """
        Gets the innermost tag that is currently open, or '' if none is
        :return:
        """
        if self.opened_tags:
            return self.opened_tags[-1]
        return ''

    def error(self, message):
        pass

    def handle_starttag(self, tag, attrs):
        """
        Handles the start tag of an HTML node in the tree
        :param tag: the HTML tag
        :param attrs: the tag attributes
        :return:
        """
        self.opened_tags.append(tag)
        if tag in self.block_tags:
            self.block_content = self.block_content.strip()
            if len(self.block_content) > 0:
                if not self.block_content.endswith('.'):
                    self.block_content += '.'
                # The input carries newlines as literal backslash sequences
                # (two characters each), hence the escaped patterns.
                self.block_content = self.block_content.replace('\\n', ' ').replace('\\r', ' ')
                # Raw string: same pattern as before, without the invalid
                # '\s' escape sequence of the non-raw literal.
                self.block_content = re.sub(r"\s\s+", " ", self.block_content)
                self.blocks.append(self.block_content)
                self.block_content = ''

    def handle_endtag(self, tag):
        """
        Handles the end tag of an HTML node in the tree
        :param tag: the HTML tag
        :return:
        """
        if self.opened_tags:
            self.opened_tags.pop()

    def handle_data(self, data):
        """
        Handles a text HTML node in the tree
        :param data: the text node
        :return:
        """
        last_opened_tag = self.get_last_opened_tag()
        if last_opened_tag in self.allowed_tags:
            data = data.replace(' ', ' ').strip()
            if data != '':
                self.block_content += data + ' '

    def get_text(self):
        """Returns all flushed blocks joined by blank lines (unflushed trailing text is not included)"""
        return "\n\n".join(self.blocks)
|