anabmaulana committed on
Commit
592f71e
·
0 Parent(s):
README.md ADDED
@@ -0,0 +1,12 @@
1
+ This repo is modified from [Quin+](https://github.com/algoprog/Quin).
2
+
3
+ To set up:
4
+ 1. Copy the `../models` folder into this folder.
5
+ 2. Follow the instructions in the original repository to download the models.
6
+ 3. Install the requirements. The `requirements.txt` in the original repo was written for Python 3.7, while the requirements in this repo target Python 3.8; choose the file that matches the Python version available on your system.
7
+ 4. Run
8
+ ```shell
9
+ python quin_web.py
10
+ ```
11
+
12
+ To check that the backend is running correctly, set up the frontend and enter a query there. By default, at the `INFO` level, the logger records each incoming query.
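+
+ As a quick sanity check (a minimal sketch, assuming the Flask `/search` endpoint defined in `app.py` is being served on its default port 12345), you can also query the backend directly:
+ ```python
+ # Hypothetical direct check against the /search endpoint; adjust host/port to your setup.
+ import requests
+
+ r = requests.get("http://localhost:12345/search",
+                  params={"query": "who invented the telephone", "limit": 5})
+ print(r.status_code)     # expect 200
+ print(r.json()["type"])  # 'question' or 'statement'
+ ```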
app.py ADDED
@@ -0,0 +1,277 @@
1
+ import argparse
2
+ import gradio as gr
3
+ import json
4
+ import logging
5
+ import nltk
6
+ import numpy
7
+ import os
8
+ import re
9
+ import requests
10
+ import torch
11
+ import urllib
12
+ import urllib.parse as urlparse
13
+
14
+ from bs4 import BeautifulSoup
15
+ from flask import request, Flask, jsonify
16
+ from flask_cors import CORS
17
+ from huggingface_hub import hf_hub_download, snapshot_download, login
18
+ from models.nli import NLI
19
+ from models.qa_ranker import PassageRanker
20
+ from models.text_encoder import SentenceTransformer
21
+ from multiprocessing.pool import ThreadPool
22
+ from scipy import spatial
23
+ from scipy.special import softmax
24
+ from time import sleep
25
+ from urllib.parse import parse_qs
26
+ from web_search import get_html, WebParser, duckduckgo_search
27
+
28
+
29
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
30
+ API_URLS = ['https://letscheck.nus.edu.sg/quin/']
31
+ logging.getLogger().setLevel(logging.INFO)
32
+
33
+
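+ # Heuristic: treat the query as a question if it starts with a wh-/aux word or ends with '?';
+ # otherwise, a declarative query containing a verb is treated as a statement (claim).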
34
+ def is_question(query):
35
+ if re.match(r'^(who|when|what|why|which|whose|is|are|was|were|do|does|did|how)', query) or query.endswith('?'):
36
+ return True
37
+ pos_tags = nltk.pos_tag(nltk.word_tokenize(query))
38
+ for tag in pos_tags:
39
+ if tag[1].startswith('VB'):
40
+ return False
41
+ return True
42
+
43
+
44
+ def clear_cache():
45
+ for url in API_URLS:
46
+ print(f"Clearing cache from {url}")
47
+ r = requests.post(f'{url}/clear_cache?secret=123')
48
+ print(r.status_code)
49
+ print(r)
50
+
51
+
52
+ class Quin:
53
+ def __init__(self):
54
+ nltk.download('punkt')
55
+ self.sent_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
56
+ self.sent_tokenizer._params.abbrev_types.update(['e.g', 'i.e', 'subsp'])
57
+ self.app = Flask(__name__)
58
+ CORS(self.app)
59
+
60
+ snapshot_download(repo_id="anabmaulana/qrbert-multitask", local_dir="cache/qrbert-multitask")
61
+ hf_hub_download(repo_id="anabmaulana/quin-passage-ranker", filename="passage_ranker.state_dict", local_dir="cache/quin-passage-ranker")
62
+ hf_hub_download(repo_id="anabmaulana/quin-nli", filename="nli.state_dict", local_dir="cache/quin-nli")
63
+
64
+ # torch.cuda.is_available = lambda: False
65
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
66
+
67
+ self.text_embedding_model = SentenceTransformer('cache/qrbert-multitask',
68
+ device=device,
69
+ parallel=True)
70
+ self.passage_ranking_model = PassageRanker(
71
+ model_path='cache/quin-passage-ranker/passage_ranker.state_dict',
72
+ gpu=torch.cuda.is_available(),
73
+ batch_size=16)
74
+ self.passage_ranking_model.eval()
75
+ self.nli_model = NLI('cache/quin-nli/nli.state_dict',
76
+ batch_size=64,
77
+ device=device,
78
+ parallel=True)
79
+ self.nli_model.eval()
80
+ logging.info('Initialized!')
81
+
82
+ def extract_snippets(self, text, sentences_per_snippet=4):
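+ # Split text into non-overlapping windows of `sentences_per_snippet` sentences,
+ # keeping only snippets longer than four words; leftover sentences form a final snippet.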
83
+ sentences = self.sent_tokenizer.tokenize(text)
84
+ snippets = []
85
+ i = 0
86
+ last_index = 0
87
+ while i < len(sentences):
88
+ snippet = ' '.join(sentences[i:i + sentences_per_snippet])
89
+ if len(snippet.split(' ')) > 4:
90
+ snippets.append(snippet)
91
+ last_index = i + sentences_per_snippet
92
+ i += sentences_per_snippet
93
+ if last_index < len(sentences):
94
+ snippet = ' '.join(sentences[last_index:])
95
+ if len(snippet.split(' ')) > 4:
96
+ snippets.append(snippet)
97
+ return snippets
98
+
99
+ def search_web_evidence(self, query, limit=10):
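+ # Pipeline: DuckDuckGo search -> parallel page download -> snippet extraction ->
+ # for questions: passage re-ranking; for statements: evidence highlighting,
+ # NLI stance classification, and a veracity verdict.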
100
+ logging.info('searching the web...')
101
+ urls = duckduckgo_search(query, pages=1)[:limit]
102
+ logging.info('downloading {} web pages...'.format(len(urls)))
103
+ search_results = []
104
+ # Move this out of the function, since it is used a few times
105
+ query_is_question = is_question(query)
106
+
107
+ # Local methods for multithreading
108
+ def download(url):
109
+ nonlocal search_results
110
+ data = get_html(url)
111
+ soup = BeautifulSoup(data, features='lxml')
112
+ title = soup.title.string if soup.title else url  # pages without a <title> would otherwise raise AttributeError
113
+ w = WebParser()
114
+ w.feed(data)
115
+ new_snippets = sum([self.extract_snippets(b) for b in w.blocks if b.count(' ') > 20], [])
116
+ new_snippets = [{'snippet': snippet, 'url': url, 'title': title} for snippet in new_snippets]
117
+ search_results += new_snippets
118
+
119
+ def timeout_download(arg):
120
+ pool = ThreadPool(1)
121
+ try:
122
+ pool.apply_async(download, [arg]).get(timeout=2)
123
+ except Exception:  # drop pages that fail to download or parse within the 2 s budget
124
+ pass
125
+ pool.close()
126
+ pool.join()
127
+
128
+ p = ThreadPool(32)
129
+ p.map(timeout_download, urls)
130
+ p.close()
131
+ p.join()
132
+
133
+ if query_is_question:
134
+ logging.info('re-ranking...')
135
+ snippets = [s['snippet'] for s in search_results]
136
+ qa_pairs = [(query, snippet) for snippet in snippets]
137
+ _, probs = self.passage_ranking_model(qa_pairs)
138
+ probs = [softmax(p)[1] for p in probs]
139
+ filtered_results = []
140
+ for i in range(len(search_results)):
141
+ if probs[i] > 0.35:
142
+ search_results[i]['score'] = str(probs[i])
143
+ filtered_results.append(search_results[i])
144
+ search_results = filtered_results
145
+ search_results = sorted(search_results, key=lambda x: float(x['score']), reverse=True)
146
+
147
+ # Split into a different function below
148
+ # highlight_results(self, search_results)
149
+ # highlight most relevant sentences
150
+ logging.info('highlighting...')
151
+ results_sentences = []
152
+ sentences_texts = []
153
+ sentences_vectors = {}
154
+ for i, r in enumerate(search_results):
155
+ sentences = self.sent_tokenizer.tokenize(r['snippet'])
156
+ sentences = [s for s in sentences if len(s.split(' ')) > 4]
157
+ sentences_texts.extend(sentences)
158
+ results_sentences.append(sentences)
159
+
160
+ vectors = self.text_embedding_model.encode(sentences=sentences_texts, batch_size=128)
161
+ for i, v in enumerate(vectors):
162
+ sentences_vectors[sentences_texts[i]] = v
163
+
164
+ query_vector = self.text_embedding_model.encode(sentences=[query], batch_size=1)[0]
165
+ for i, sentences in enumerate(results_sentences):
166
+ best_sentences = set()
167
+ evidence_sentences = []
168
+ for sentence in sentences:
169
+ sentence_vector = sentences_vectors[sentence]
170
+ score = 1 - spatial.distance.cosine(query_vector, sentence_vector)
171
+ if score > 0.91:
172
+ best_sentences.add(sentence)
173
+ evidence_sentences.append(sentence)
174
+ if len(evidence_sentences) > 0:
175
+ search_results[i]['evidence'] = ' '.join(evidence_sentences)
176
+ search_results[i]['snippet'] = \
177
+ ' '.join([s if s not in best_sentences else '<b>{}</b>'.format(s) for s in sentences])
178
+
179
+ search_results = [s for s in search_results if 'evidence' in s]
180
+
181
+ # Split into a different function below
182
+ # entailment_classification(self, search_results)
183
+ if not query_is_question:
184
+ logging.info('entailment classification...')
185
+ es_pairs = []
186
+ for result in search_results:
187
+ evidence = result['evidence']
188
+ es_pairs.append((evidence, query))
189
+ try:
190
+ labels, probs = self.nli_model(es_pairs)
191
+ logging.info(str(labels))
192
+ for i in range(len(labels)):
193
+ confidence = numpy.exp(numpy.max(probs[i]))
194
+ if confidence > 0.4:
195
+ search_results[i]['nli_class'] = labels[i]
196
+ else:
197
+ search_results[i]['nli_class'] = 'neutral'
198
+ search_results[i]['nli_confidence'] = str(confidence)
199
+ except Exception as e:
200
+ logging.warning('Error doing entailment classification')
201
+ logging.warning(str(e))
202
+ try:
203
+ search_results = [s for s in search_results if s['nli_class'] != 'neutral']
204
+ except KeyError:  # some results may lack 'nli_class' if classification failed above
205
+ pass
206
+
207
+ filtered_results = []
208
+ added_urls = set()
209
+ supporting = 0
210
+ refuting = 0
211
+ for r in search_results:
212
+ if r['url'] not in added_urls:
213
+ filtered_results.append(r)
214
+ added_urls.add(r['url'])
215
+ if 'nli_class' in r:
216
+ if r['nli_class'] == 'entailment':
217
+ supporting += 1
218
+ elif r['nli_class'] == 'contradiction':
219
+ refuting += 1
220
+ search_results = filtered_results
221
+
222
+ if supporting > 4 or refuting > 4:
223
+ if supporting / (refuting + 0.001) > 1.7:
224
+ veracity = 'Probably True'
225
+ elif refuting / (supporting + 0.001) > 1.7:
226
+ veracity = 'Probably False'
227
+ else:
228
+ veracity = '? Ambiguous'
229
+ else:
230
+ veracity = 'Not enough evidence'
231
+
232
+ search_results = search_results[:limit]
233
+
234
+ logging.info('done searching')
235
+
236
+ if query_is_question:
237
+ return {'type': 'question',
238
+ 'results': search_results}
239
+ else:
240
+ return {'type': 'statement',
241
+ 'supporting': supporting,
242
+ 'refuting': refuting,
243
+ 'veracity_rating': veracity,
244
+ 'results': search_results}
245
+
246
+ def build_endpoints(self):
247
+ @self.app.route('/search', methods=['POST', 'GET'])
248
+ def search_endpoint():
249
+ query = (request.args.get('query') or '').lower()  # guard against a missing 'query' parameter
250
+ limit = request.args.get('limit') or 100
251
+ limit = min(int(limit), 100)
252
+ results = json.dumps(self.search_web_evidence(query, limit=limit), indent=4)
253
+ return results
254
+
255
+ def serve(self, port=12345):
256
+ self.build_endpoints()
257
+ self.app.run(host='0.0.0.0', port=port)
258
+
259
+
260
+
261
+ fchecker = Quin()
262
+
263
+ def predict(query, limit):
264
+ return fchecker.search_web_evidence(query, limit=int(limit))  # gr.Number yields a float; slicing needs an int
265
+
266
+ demo = gr.Interface(
267
+ fn=predict,
268
+ inputs=[
269
+ gr.Textbox(label="Enter Query"), # string input
270
+ gr.Number(label="Enter Limit") # number input
271
+ ],
272
+ outputs=gr.JSON(label="Output")
273
+ )
274
+
275
+
276
+ if __name__ == '__main__':
277
+ demo.launch()
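+ # Note: demo.launch() starts the Gradio UI only; the Flask /search endpoint above
+ # is exposed only if Quin().serve() (default port 12345) is called instead.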
models/__init__.py ADDED
File without changes
models/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (157 Bytes). View file
 
models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (128 Bytes). View file
 
models/__pycache__/data_utils.cpython-37.pyc ADDED
Binary file (23.7 kB). View file
 
models/__pycache__/dense_retriever.cpython-37.pyc ADDED
Binary file (2.42 kB). View file
 
models/__pycache__/nli.cpython-37.pyc ADDED
Binary file (8.4 kB). View file
 
models/__pycache__/nli.cpython-39.pyc ADDED
Binary file (8.38 kB). View file
 
models/__pycache__/qa_ranker.cpython-37.pyc ADDED
Binary file (9.03 kB). View file
 
models/__pycache__/text_encoder.cpython-37.pyc ADDED
Binary file (29.5 kB). View file
 
models/__pycache__/tokenization.cpython-37.pyc ADDED
Binary file (1.23 kB). View file
 
models/__pycache__/vector_index.cpython-37.pyc ADDED
Binary file (3.42 kB). View file
 
models/data_utils.py ADDED
@@ -0,0 +1,1006 @@
1
+ import csv
2
+ import json
3
+ import pickle
4
+
5
+ import logging
6
+ import re
7
+ from pprint import pprint
8
+
9
+ import pandas
10
+ import gzip
11
+ import os
12
+
13
+ import numpy as np
14
+
15
+ from random import randint, random
16
+ from tqdm import tqdm
17
+ from typing import List, Dict
18
+
19
+ from spacy.lang.en import English
20
+ # from tensorflow.keras.utils import Sequence
21
+ from transformers import RobertaTokenizer, BertTokenizer, BertModel, AutoTokenizer, AutoModel
22
+
23
+ # from utils import WebParser
24
+ from models.dense_retriever import DenseRetriever
25
+ from models.tokenization import tokenize
26
+
27
+ from typing import Union, List
28
+
29
+ from unidecode import unidecode
30
+
31
+ import spacy
32
+
33
+
34
+ # from utils import WebParser
35
+
36
+
37
+ class InputExample:
38
+ """
39
+ Structure for one input example with texts, the label and a unique id
40
+ """
41
+
42
+ def __init__(self, guid: str, texts: List[str], label: Union[int, float]):
43
+ """
44
+ Creates one InputExample with the given texts, guid and label
45
+ str.strip() is called on both texts.
46
+ :param guid
47
+ id for the example
48
+ :param texts
49
+ the texts for the example
50
+ :param label
51
+ the label for the example
52
+ """
53
+ self.guid = guid
54
+ self.texts = [text.strip() for text in texts]
55
+ self.label = label
56
+
57
+ def get_texts(self):
58
+ return self.texts
59
+
60
+ def get_label(self):
61
+ return self.label
62
+
63
+
64
+ class LoggingHandler(logging.Handler):
65
+ def __init__(self, level=logging.NOTSET):
66
+ super().__init__(level)
67
+
68
+ def emit(self, record):
69
+ try:
70
+ msg = self.format(record)
71
+ tqdm.write(msg)
72
+ self.flush()
73
+ except (KeyboardInterrupt, SystemExit):
74
+ raise
75
+ except:
76
+ self.handleError(record)
77
+
78
+
79
+ def get_examples(filename, max_examples=0):
80
+ examples = []
81
+ id = 0
82
+ with open(filename, encoding='utf8') as file:
83
+ for j, line in enumerate(file):
84
+ line = line.rstrip('\n')
85
+ sample = json.loads(line)
86
+ label = sample['label']
87
+ guid = "%s-%d" % (filename, id)
88
+ id += 1
89
+ if label == 'entailment':
90
+ label = 0
91
+ elif label == 'contradiction':
92
+ label = 1
93
+ else:
94
+ label = 2
95
+ examples.append(InputExample(guid=guid,
96
+ texts=[sample['s1'], sample['s2']],
97
+ label=label))
98
+ if 0 < max_examples <= len(examples):
99
+ break
100
+ return examples
101
+
102
+
103
+ def get_qa_examples(filename, max_examples=0, dev=False):
104
+ examples = []
105
+ id = 0
106
+ with open(filename, encoding='utf8') as file:
107
+ for j, line in enumerate(file):
108
+ line = line.rstrip('\n')
109
+ sample = json.loads(line)
110
+ label = sample['relevant']
111
+ guid = "%s-%d" % (filename, id)
112
+ id += 1
113
+ examples.append(InputExample(guid=guid,
114
+ texts=[sample['question'], sample['answer']],
115
+ label=label))
116
+ if not dev:
117
+ if label == 1:
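+ # oversample relevant (label == 1) pairs ~14x to counter class imbalance in the training data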
118
+ for _ in range(13):
119
+ examples.append(InputExample(guid=guid,
120
+ texts=[sample['question'], sample['answer']],
121
+ label=label))
122
+ if 0 < max_examples <= len(examples):
123
+ break
124
+ return examples
125
+
126
+
127
+ def map_label(label):
128
+ labels = {"relevant": 0, "irrelevant": 1}
129
+ return labels[label.strip().lower()]
130
+
131
+
132
+ def get_qar_examples(filename, max_examples=0):
133
+ examples = []
134
+ id = 0
135
+ with open(filename, encoding='utf8') as file:
136
+ for j, line in enumerate(file):
137
+ line = line.rstrip('\n')
138
+ sample = json.loads(line)
139
+ guid = "%s-%d" % (filename, id)
140
+ id += 1
141
+ examples.append(InputExample(guid=guid,
142
+ texts=[sample['question'], sample['answer']],
143
+ label=1.0))
144
+ if 0 < max_examples <= len(examples):
145
+ break
146
+ return examples
147
+
148
+
149
+ def get_qar_artificial_examples():
150
+ examples = []
151
+ id = 0
152
+
153
+ print('Loading passages...')
154
+
155
+ passages = []
156
+ file = open('data/msmarco/collection.tsv', 'r', encoding='utf8')
157
+ while True:
158
+ line = file.readline()
159
+ if not line:
160
+ break
161
+ line = line.rstrip('\n').split('\t')
162
+ passages.append(line[1])
163
+
164
+ print('Loaded passages')
165
+
166
+ with open('data/qar/qar_artificial_queries.csv') as f:
167
+ for i, line in enumerate(f):
168
+ queries = line.rstrip('\n').split('|')
169
+ for query in queries:
170
+ guid = "%s-%d" % ('', id)
171
+ id += 1
172
+ examples.append(InputExample(guid=guid,
173
+ texts=[query, passages[i]],
174
+ label=1.0))
175
+ return examples
176
+
177
+
178
+ def get_single_examples(filename, max_examples=0):
179
+ examples = []
180
+ id = 0
181
+ with open(filename, encoding='utf8') as file:
182
+ for j, line in enumerate(file):
183
+ line = line.rstrip('\n')
184
+ sample = json.loads(line)
185
+ guid = "%s-%d" % (filename, id)
186
+ id += 1
187
+ examples.append(InputExample(guid=guid,
188
+ texts=[sample['text']],
189
+ label=1))
190
+ if 0 < max_examples <= len(examples):
191
+ break
192
+ return examples
193
+
194
+
195
+ def get_qnli_examples(filename, max_examples=0, no_contradictions=False, fever_only=False):
196
+ examples = []
197
+ id = 0
198
+ with open(filename, encoding='utf8') as file:
199
+ for j, line in enumerate(file):
200
+ line = line.rstrip('\n')
201
+ sample = json.loads(line)
202
+ label = sample['label']
203
+ if label == 'contradiction' and no_contradictions:
204
+ continue
205
+ if sample['evidence'] == '':
206
+ continue
207
+ if fever_only and sample['source'] != 'fever':
208
+ continue
209
+ guid = "%s-%d" % (filename, id)
210
+ id += 1
211
+
212
+ examples.append(InputExample(guid=guid,
213
+ texts=[sample['statement'].strip(), sample['evidence'].strip()],
214
+ label=1.0))
215
+ if 0 < max_examples <= len(examples):
216
+ break
217
+ return examples
218
+
219
+
220
+ def get_retrieval_examples(filename, negative_corpus='data/msmarco/collection.tsv', max_examples=0, no_statements=True,
221
+ encoder_model=None, negative_samples_num=4):
222
+ examples = []
223
+ queries = []
224
+ passages = []
225
+ negative_passages = []
226
+ id = 0
227
+ with open(filename, encoding='utf8') as file:
228
+ for j, line in enumerate(file):
229
+ line = line.rstrip('\n')
230
+ sample = json.loads(line)
231
+
232
+ if 'evidence' in sample and sample['evidence'] == '':
233
+ continue
234
+
235
+ guid = "%s-%d" % (filename, id)
236
+ id += 1
237
+
238
+ if sample['type'] == 'question':
239
+ query = sample['question']
240
+ passage = sample['answer']
241
+ else:
242
+ query = sample['statement']
243
+ passage = sample['evidence']
244
+
245
+ query = query.strip()
246
+ passage = passage.strip()
247
+
248
+ if sample['type'] == 'statement' and no_statements:
249
+ continue
250
+
251
+ queries.append(query)
252
+ passages.append(passage)
253
+
254
+ if sample['source'] == 'natural-questions':
255
+ negative_passages.append(passage)
256
+
257
+ if max_examples == len(passages):
258
+ break
259
+
260
+ if encoder_model is not None:
261
+ # Load MSMARCO passages
262
+ logging.info('Loading MSM passages...')
263
+ with open(negative_corpus) as file:
264
+ for line in file:
265
+ p = line.rstrip('\n').split('\t')[1]
266
+ negative_passages.append(p)
267
+
268
+ logging.info('Building ANN index...')
269
+ dense_retriever = DenseRetriever(model=encoder_model, batch_size=1024, use_gpu=True)
270
+ dense_retriever.create_index_from_documents(negative_passages)
271
+ results = dense_retriever.search(queries=queries, limit=100, probes=256)
272
+ negative_samples = [
273
+ [negative_passages[p[0]] for p in r if negative_passages[p[0]] != passages[i]][:negative_samples_num]
274
+ for i, r in enumerate(results)
275
+ ]
276
+ # print(queries[0])
277
+ # print(negative_samples[0][0])
278
+
279
+ for i in range(len(queries)):
280
+ texts = [queries[i], passages[i]] + negative_samples[i]
281
+ examples.append(InputExample(guid=guid,
282
+ texts=texts,
283
+ label=1.0))
284
+
285
+ else:
286
+ for i in range(len(queries)):
287
+ texts = [queries[i], passages[i]]
288
+ examples.append(InputExample(guid=guid,
289
+ texts=texts,
290
+ label=1.0))
291
+
292
+ return examples
293
+
294
+
295
+ def get_ict_examples(filename, max_examples=0):
296
+ examples = []
297
+ id = 0
298
+ with open(filename, encoding='utf8') as file:
299
+ for j, line in enumerate(file):
300
+ line = line.rstrip('\n')
301
+ sample = json.loads(line)
302
+ # label = sample['label']
303
+ guid = "%s-%d" % (filename, id)
304
+ id += 1
305
+ examples.append(InputExample(guid=guid,
306
+ texts=[sample['s1'].strip(), sample['s2'].strip()],
307
+ label=1.0))
308
+ if 0 < max_examples <= len(examples):
309
+ break
310
+ return examples
311
+
312
+
313
+ def preprocess_fever_dataset(path):
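+ # Convert FEVER claims into (statement, evidence) pairs: resolve the annotated evidence
+ # sentence ids against the wiki dump, then pad each evidence set with surrounding
+ # article context up to ~254 words.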
314
+ def replace_symbols(text):
315
+ return text.replace('-LRB-', '(') \
316
+ .replace('-RRB-', ')') \
317
+ .replace('-LSB-', '[') \
318
+ .replace('-RSB-', ']') \
319
+ .replace('-LCB-', '{') \
320
+ .replace('-RCB-', '}')
321
+
322
+ print('Loading wiki articles...')
323
+ articles = {}
324
+ for i in range(109):
325
+ wid = str(i + 1).zfill(3)
326
+ file = open(path + '/wiki-pages/wiki-' + wid + '.jsonl', 'r', encoding='utf8')
327
+ for line in file:
328
+ data = json.loads(line.rstrip('\n'))
329
+ lines = data['lines'].split('\n')
330
+ plines = []
331
+ for line in lines:
332
+ slines = line.split('\t')
333
+ if len(slines) > 1 and slines[0].isnumeric():
334
+ plines.append(slines[1])
335
+ else:
336
+ plines.append('')
337
+
338
+ lines = plines
339
+ data['id'] = unidecode(data['id'])
340
+ articles[data['id']] = lines
341
+
342
+ print('Preprocessing dataset...')
343
+ files = ['train', 'dev']
344
+ for file in files:
345
+ fo = open(path + '/facts_{}.jsonl'.format(file), 'w+', encoding='utf8')
346
+ with open(path + '/{}.jsonl'.format(file), encoding='utf8') as f:
347
+ for line in f:
348
+ data = json.loads(line.rstrip('\n'))
349
+ claim = data['claim']
350
+ label = data['label']
351
+ evidence = []
352
+ sents = []
353
+ if label != 'NOT ENOUGH INFO':
354
+ for evidence_list in data['evidence']:
355
+ evidence_sents = []
356
+ extra_left = []
357
+ extra_right = []
358
+ evidence_words = 0
359
+ for e in evidence_list:
360
+ e[2] = unidecode(e[2])
361
+ if e[2] in articles and len(articles[e[2]]) > e[3]:
362
+ eid = e[3]
363
+ article = articles[e[2]]
364
+ evidence_sents.append(replace_symbols(article[eid].replace('\u00a0', ' ').strip()))  # assumed: normalizes non-breaking spaces (the original character was lost in extraction)
365
+ evidence_words += len(article[eid].split(' '))
366
+ left = []
367
+ for i in range(0, eid):
368
+ left.append(replace_symbols(article[i]))
369
+ extra_left.append(left)
370
+ right = []
371
+ for i in range(eid + 1, len(article)):
372
+ right.append(replace_symbols(article[i]))
373
+ extra_right.append(right)
374
+ else:
375
+ evidence_sents = []
376
+ break
377
+ evidence_text = []
378
+ for i in range(len(evidence_sents)):
379
+ for j in range(len(extra_left[i])):
380
+ xlen = len(extra_left[i][j].split(' '))
381
+ if evidence_words + xlen < 254:
382
+ evidence_text.append(extra_left[i][j])
383
+ evidence_words += xlen
384
+ evidence_text.append(evidence_sents[i])
385
+ for j in range(len(extra_right[i])):
386
+ xlen = len(extra_right[i][j].split(' '))
387
+ if evidence_words + xlen < 254:
388
+ evidence_text.append(extra_right[i][j])
389
+ evidence_words += xlen
390
+
391
+ if len(evidence_text) > 0:
392
+ evidence_text = unidecode(' '.join(evidence_text))
393
+ evidence_text = ' '.join(evidence_text.split()).strip()
394
+ evidence.append(evidence_text)
395
+ sents.append(' '.join(evidence_sents))
396
+
397
+ if label == 'NOT ENOUGH INFO' or len(evidence) > 0:
398
+ fo.write(json.dumps({
399
+ 'statement': unidecode(replace_symbols(claim)),
400
+ 'label': label,
401
+ 'evidence_long': list(set(evidence)),
402
+ 'evidence': list(set(sents))
403
+ }) + '\n')
404
+ fo.close()
405
+
406
+
407
+ def preprocess_squad_dataset(path, output):
408
+ # Read dataset
409
+ with open(path) as f:
410
+ dataset = json.load(f)
411
+
412
+ # Iterate and write question-answer pairs
413
+ with open(output, 'w+') as f:
414
+ for article in dataset['data']:
415
+ for paragraph in article['paragraphs']:
416
+ context = paragraph['context']
417
+ for qa in paragraph['qas']:
418
+ question = qa['question']
419
+ answers = [a['text'] for a in qa['answers']]
420
+ min_answer = ''
421
+ min_length = 1000
422
+ for answer in answers:
423
+ if len(answer) < min_length:
424
+ min_answer = answer
425
+ min_length = len(answer)
426
+ if min_answer == '':
427
+ continue
428
+ f.write(json.dumps({'question': question, 'answer': min_answer, 'evidence': context}))
429
+ f.write('\n')
430
+
431
+
432
+ def preprocess_nq_dataset():
433
+ def _clean_token(token):
434
+ return re.sub(u" ", "_", token["token"])
435
+
436
+ def get_text_blocks(html):
437
+ w = WebParser()
438
+ html = '<div>{}</div>'.format(html)
439
+ w.feed(html)
440
+ blocks = w.get_blocks()
441
+ return blocks
442
+
443
+ files = ['train']
444
+ for file in files:
445
+ of = open('../data/qa/nq/{}_all_h.jsonl'.format(file), 'w+')
446
+ with open('../data/qa/nq/{}.jsonl'.format(file), encoding='utf8') as f:
447
+ for line in tqdm(f):
448
+ data = json.loads(line.rstrip('\n'))
449
+ question = data['question_text']
450
+
451
+ if file == 'dev':
452
+ data['document_text'] = " ".join([_clean_token(t) for t in data["document_tokens"]])
453
+ doc_tokens = data['document_text'].replace('*', '').split(' ')
454
+
455
+ short_answer = ''
456
+ if len(data['annotations'][0]['short_answers']) > 0:
457
+ short_answer_info = data['annotations'][0]['short_answers'][0]
458
+ doc_tokens.insert(short_answer_info['start_token'], '*')
459
+ doc_tokens.insert(short_answer_info['end_token'] + 1, '*')
460
+ short_answer = ' '.join(
461
+ doc_tokens[short_answer_info['start_token'] + 1:short_answer_info['end_token'] + 1])
462
+ short_answer = ' '.join(get_text_blocks(short_answer))
463
+
464
+ long_answer_info = data['annotations'][0]['long_answer']
465
+ long_answer = ' '.join(doc_tokens[long_answer_info['start_token']:long_answer_info['end_token']])
466
+ long_answer = ' '.join(get_text_blocks(long_answer))
467
+
468
+ if long_answer == '':
469
+ continue
470
+
471
+ example = {
472
+ 'question': question,
473
+ 'answer': short_answer,
474
+ 'evidence': long_answer
475
+ }
476
+
477
+ of.write(json.dumps(example) + '\n')
478
+
479
+
480
+ def preprocess_msmarco():
481
+ files = ['train', 'dev']
482
+ for file in files:
483
+ fo = open('../data/qa/msmarco/{}_all.jsonl'.format(file), 'w+')
484
+ with open('../data/qa/msmarco/{}.jsonl'.format(file), 'r') as f:
485
+ for line in f:
486
+ data = json.loads(line.rstrip('\n'))
487
+ q = data['query']
488
+ a = data['answers'][0]
489
+ evidence = ''
490
+ for passage in data['passages']:
491
+ if passage['is_selected'] == 1:
492
+ evidence = passage['passage_text']
493
+ break
494
+
495
+ if '_' in q:
496
+ continue
497
+
498
+ example = {
499
+ 'question': q,
500
+ 'answer': a,
501
+ 'evidence': evidence
502
+ }
503
+
504
+ fo.write(json.dumps(example) + '\n')
505
+
506
+
507
+ def create_qa_ranking_dataset():
508
+ files = ['train', 'dev']
509
+ for file in files:
510
+ relevant = 0
511
+ irrelevant = 0
512
+ max_irrelevant = 10 if file == 'train' else 1
513
+ fo = open('../data/qa_ranking_large/{}.jsonl'.format(file), 'w+')
514
+ with open('../data/qa/msmarco/{}.jsonl'.format(file), 'r') as f:
515
+ for line in f:
516
+ data = json.loads(line.rstrip('\n'))
517
+ added_irrelevant = 0
518
+ for passage in data['passages']:
519
+ example = {
520
+ 'question': data['query'],
521
+ 'answer': passage['passage_text'],
522
+ 'relevant': passage['is_selected']
523
+ }
524
+ if passage['is_selected'] == 1 or added_irrelevant < max_irrelevant:
525
+ fo.write(json.dumps(example) + '\n')
526
+ if passage['is_selected'] == 1:
527
+ relevant += 1
528
+ else:
529
+ added_irrelevant += 1
530
+ irrelevant += 1
531
+
532
+ print(relevant, irrelevant)
533
+
534
+
535
+ def preprocess_qa_nli_dataset(path):
536
+ files = ['dev', 'train']
537
+ for file in files:
538
+ of = open(path + '/' + file + '_.tsv', 'w+')
539
+ with open(path + '/' + file + '.tsv', 'r') as f:
540
+ for line in f:
541
+ line = line.rstrip()
542
+ s = line.split('\t')
543
+ if s[2].endswith('.'):
544
+ s[2] = s[2][:len(s[2]) - 1] + '?'
545
+ if not s[2].endswith('?'):
546
+ s[2] += ' ?'
547
+ input_str = s[2] + ' ' + s[3]
548
+ output_str = s[4]
549
+ of.write(input_str + '\t' + output_str + '\n')
550
+ of.close()
551
+
552
+
553
+ def get_pair_input(tokenizer, sent1, sent2, max_len=256):
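+ # Build BERT-style inputs for a sentence pair: token ids plus segment ids
+ # (0 up to and including the first [SEP], 1 for the second sentence).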
554
+ text = "[CLS] {} [SEP] {} [SEP]".format(sent1, sent2)
555
+
556
+ tokenized_text = tokenizer.tokenize(text)[:max_len]
557
+ indexed_tokens = tokenizer.encode(text)[:max_len]
558
+
559
+ segments_ids = []
560
+ sep_flag = False
561
+ for i in range(len(tokenized_text)):
562
+ if tokenized_text[i] == '[SEP]' and not sep_flag:
563
+ segments_ids.append(0)
564
+ sep_flag = True
565
+ elif sep_flag:
566
+ segments_ids.append(1)
567
+ else:
568
+ segments_ids.append(0)
569
+ return indexed_tokens, segments_ids
570
+
571
+
572
+ def build_batch(tokenizer, text_list, max_len=256):
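+ # Tokenize a batch of sentence pairs and pad token ids, segment ids and attention
+ # masks to the longest sequence in the batch.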
573
+ token_id_list = []
574
+ segment_list = []
575
+ attention_masks = []
576
+ longest = -1
577
+
578
+ for pair in text_list:
579
+ sent1, sent2 = pair
580
+ ids, segs = get_pair_input(tokenizer, sent1, sent2, max_len=max_len)
581
+ if ids is None or segs is None:
582
+ continue
583
+ token_id_list.append(ids)
584
+ segment_list.append(segs)
585
+ attention_masks.append([1] * len(ids))
586
+ if len(ids) > longest:
587
+ longest = len(ids)
588
+
589
+ if len(token_id_list) == 0:
590
+ return None, None, None
591
+
592
+ # padding
593
+ assert (len(token_id_list) == len(segment_list))
594
+ for ii in range(len(token_id_list)):
595
+ token_id_list[ii] += [0] * (longest - len(token_id_list[ii]))
596
+ attention_masks[ii] += [0] * (longest - len(attention_masks[ii]))  # pad with 0: padded positions must not be attended to
597
+ segment_list[ii] += [1] * (longest - len(segment_list[ii]))
598
+
599
+ return token_id_list, segment_list, attention_masks
600
+
601
+
602
+ """
603
+ class QAEncoderBatchGenerator(Sequence):
604
+ def __init__(self, ids_to_paragraphs: List, ids_to_queries: Dict, qrels_train: List, tokenizer,
605
+ max_question_length=64, max_answer_length=512, batch_size=64):
606
+ np.random.seed(42)
607
+ self.ids_to_paragraphs = ids_to_paragraphs
608
+ self.ids_to_queries = ids_to_queries
609
+ self.qrels_train = qrels_train
610
+ self.tokenizer = tokenizer
611
+ self.max_question_length = max_question_length
612
+ self.max_answer_length = max_answer_length
613
+ self.batch_size = batch_size
614
+ self.on_epoch_end()
615
+
616
+ def __len__(self):
617
+ return int(np.floor(len(self.qrels_train) / self.batch_size))
618
+
619
+ def __getitem__(self, index):
620
+ indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
621
+
622
+ question_texts = []
623
+ answer_texts = []
624
+
625
+ for index in indexes:
626
+ query_id = self.qrels_train[index][0]
627
+ answer_id = self.qrels_train[index][1]
628
+ question_texts.append(self.ids_to_queries[query_id])
629
+ answer_texts.append(self.ids_to_paragraphs[answer_id])
630
+
631
+ question_inputs = get_inputs_text(texts=question_texts, tokenizer=self.tokenizer,
632
+ max_length=self.max_question_length)
633
+ answer_inputs = get_inputs_text(texts=answer_texts, tokenizer=self.tokenizer,
634
+ max_length=self.max_answer_length)
635
+
636
+ x = {
637
+ 'question_input_word_ids': question_inputs[0],
638
+ 'question_input_masks': question_inputs[1],
639
+ 'question_input_segments': question_inputs[2],
640
+ 'answer_input_word_ids': answer_inputs[0],
641
+ 'answer_input_masks': answer_inputs[1],
642
+ 'answer_input_segments': answer_inputs[2]
643
+ }
644
+ y = np.zeros((self.batch_size, self.batch_size))
645
+
646
+ return x, y
647
+
648
+ def on_epoch_end(self, logs=None):
649
+ self.indexes = np.arange(len(self.qrels_train))
650
+ np.random.shuffle(self.indexes)
651
+ """
652
+
653
+
654
+ def create_qa_encoder_pretraining_dataset(dataset_file, output_file, total_samples):
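+ # Generate unsupervised (query, passage) pretraining pairs, alternating between
+ # Body-First-Selection (a sentence from the first paragraph paired with a random later
+ # paragraph) and Inverse-Cloze-Task (a sentence paired with the rest of its paragraph).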
655
+ print('Loading paragraphs in memory...')
656
+
657
+ articles = []
658
+ file = open(dataset_file, 'r', encoding='utf8')
659
+ while True:
660
+ line = file.readline()
661
+ if not line:
662
+ break
663
+ line = line.rstrip('\n')
664
+ data = json.loads(line)
665
+ articles.append(data['paragraphs'])
666
+
667
+ print('Loaded paragraphs in memory')
668
+
669
+ o = open(output_file, 'w+')
670
+
671
+ added = 0
672
+ while added < total_samples:
673
+ index = randint(0, len(articles) - 1)
674
+ article = articles[index]
675
+ if len(article) < 2:
676
+ continue
677
+ if added % 2 == 0: # Body-First-Selection
678
+ paragraph = article[0]
679
+ random_question_sentence_id = randint(0, len(paragraph) - 1)
680
+ question = paragraph[random_question_sentence_id]
681
+ random_answer_paragraph_id = randint(1, len(article) - 1)
682
+ answer = ' '.join(article[random_answer_paragraph_id])
683
+ else: # Inverse-Cloze-Task
684
+ random_paragraph_id = randint(0, len(article) - 1)
685
+ paragraph = article[random_paragraph_id]
686
+ random_question_sentence_id = randint(0, len(paragraph) - 1)
687
+ question = paragraph[random_question_sentence_id]
688
+ choice = random()
689
+ if choice < 0.9:
690
+ answer = ' '.join([paragraph[j] for j in range(len(paragraph)) if j != random_question_sentence_id])
691
+ else:
692
+ answer = ' '.join(paragraph)
693
+
694
+ o.write(json.dumps({'s1': question, 's2': answer}) + '\n')
695
+
696
+ added += 1
697
+ if added % 10000 == 0:
698
+ print(added)
699
+
700
+ o.close()
701
+
702
+
703
+ def create_random_wiki_paragraphs(dataset_file, output_file, total_samples, max_length=510):
704
+ print('Loading paragraphs in memory...')
705
+
706
+ articles = []
707
+ file = open(dataset_file, 'r', encoding='utf8')
708
+ while True:
709
+ line = file.readline()
710
+ if not line:
711
+ break
712
+ line = line.rstrip('\n')
713
+ data = json.loads(line)
714
+ articles.append(data['paragraphs'])
715
+
716
+ print('Loaded paragraphs in memory')
717
+
718
+ o = open(output_file, 'w+')
719
+ added = 0
720
+ while added < total_samples:
721
+ index = randint(0, len(articles) - 1)
722
+ article = articles[index]
723
+ if len(article) < 1:
724
+ continue
725
+ random_paragraph_id = randint(0, len(article) - 1)
726
+ paragraph = article[random_paragraph_id]
727
+
728
+ p = []
729
+ words = 0
730
+ i = 0
731
+ while i < len(paragraph):
732
+ nwords = len(paragraph[i].split(' '))
733
+ if nwords + words >= max_length:
734
+ break
735
+ p.append(paragraph[i])
736
+ i += 1
737
+
738
+ o.write(json.dumps({'text': ' '.join(p)}) + '\n')
739
+ added += 1
740
+ if added % 10000 == 0:
741
+ print(added)
742
+ o.close()
743
+
744
+
745
+ def load_unsupervised_dataset(dataset_file):
746
+ print('Loading dataset...')
747
+ x = pickle.load(open(dataset_file, "rb"))
748
+ print('Done')
749
+ return x, len(x[0])
750
+
751
+
752
+ def create_entailment_dataset(dataset_file, tokenizer, output_file):
753
+ x = [[] for _ in range(6)]
754
+ y = []
755
+
756
+ with open(dataset_file, encoding='utf8') as file:
757
+ for j, line in enumerate(file):
758
+ if j % 10000 == 0:
759
+ print(j)
760
+ line = line.rstrip('\n')
761
+ sample = json.loads(line)
762
+ label = sample['label']
763
+ s1 = tokenize(text=sample['s1'],
764
+ max_length=256,
765
+ tokenizer=tokenizer)
766
+ s2 = tokenize(text=sample['s2'],
767
+ max_length=256,
768
+ tokenizer=tokenizer)
769
+ example = s1 + s2
770
+ for i in range(6):
771
+ x[i].append(example[i])
772
+ if label == 'entailment':
773
+ label = [1.0, 0.0, 0.0]
774
+ elif label == 'contradiction':
775
+ label = [0.0, 1.0, 0.0]
776
+ else:
777
+ label = [0.0, 0.0, 1.0]
778
+ y.append(np.asarray(label, dtype='float32'))
779
+
780
+ for i in range(6):
781
+ x[i] = np.asarray(x[i])
782
+
783
+ y = np.asarray(y)
784
+ data = [x, y]
785
+
786
+ pickle.dump(data, open(output_file, "wb"), protocol=4)
787
+
788
+
789
+ def load_supervised_dataset(dataset_file):
790
+ print('Loading dataset...')
791
+ d = pickle.load(open(dataset_file, "rb"))
792
+ print('Done')
793
+ return d[0], d[1]
794
+
795
+
796
+ def preprocess_sg():
797
+ fo = open('../data/qa/squad/dev_sg_sq.txt', 'w+')
798
+ f = open('../data/qa/squad/dev.jsonl', 'r')
799
+ for line in f:
800
+ d = json.loads(line.rstrip('\n'))
801
+ q = d['question']
802
+ a = d['answer']
803
+
804
+ if q.endswith('.'):
805
+ q = q[:len(q) - 1] + '?'
806
+ if not q.endswith('?'):
807
+ q += ' ?'
808
+
809
+ fo.write(q + ' ' + a + '\n')
810
+
811
+
812
+ def preprocess_msmarco_():
813
+ fo = open('../data/qa/msmarco/dev_sg.tsv', 'w+')
814
+ with open('../data/qa/msmarco/dev.jsonl', 'r') as f:
815
+ for line in f:
816
+ data = json.loads(line.rstrip('\n'))
817
+ if data['wellFormedAnswers'] == '[]' or data['query_type'] == 'DESCRIPTION':
818
+ continue
819
+ q = data['query']
820
+ a = data['answers'][0]
821
+ fa = data['wellFormedAnswers'][0]
822
+
823
+ if len(a) > len(q):
824
+ continue
825
+
826
+ if q.endswith('.'):
827
+ q = q[:len(q) - 1] + '?'
828
+ if not q.endswith('?'):
829
+ q += ' ?'
830
+
831
+ if '_' in q:
832
+ continue
833
+
834
+ fo.write('{} {}\t{}\n'.format(q, a, fa))
835
+
836
+
837
+ def create_qnli_nq():
838
+ files = ['train', 'dev']
839
+ for file in files:
840
+ o = open('../data/qnli/{}_nq.jsonl'.format(file), 'w+')
841
+
842
+ fs = open('../data/qnli/{}.txt'.format(file))
843
+ statements = [s.rstrip('\n') for s in fs]
844
+
845
+ fd = open('../data/qa/nq/{}_nli.jsonl'.format(file))
846
+ for i, line in enumerate(fd):
847
+ d = json.loads(line.rstrip('\n'))
848
+ o.write(json.dumps({
849
+ 'statement': statements[i],
850
+ 'question': d['question'],
851
+ 'answer': d['answer'],
852
+ 'evidence': d['evidence']
853
+ }) + '\n')
854
+
855
+
856
+ def create_qnli_ms():
857
+ files = ['train', 'dev']
858
+ for file in files:
859
+ o = open('../data/qnli/{}_ms.jsonl'.format(file), 'w+')
860
+ with open('../data/qa/msmarco/{}.jsonl'.format(file), 'r') as f:
861
+ for line in f:
862
+ data = json.loads(line.rstrip('\n'))
863
+ if data['wellFormedAnswers'] == '[]' or data['query_type'] == 'DESCRIPTION':
864
+ continue
865
+ q = data['query']
866
+ a = data['answers'][0]
867
+ fa = data['wellFormedAnswers'][0]
868
+
869
+ evidence = ''
870
+ for p in data['passages']:
871
+ if p['is_selected'] == 1:
872
+ evidence = p['passage_text']
873
+ if evidence == '' or '_' in q:
874
+ continue
875
+
876
+ o.write(json.dumps({
877
+ 'statement': fa,
878
+ 'question': q,
879
+ 'answer': a,
880
+ 'evidence': evidence
881
+ }) + '\n')
882
+
883
+
884
+ def create_qnli_ms_2():
885
+ files = ['train', 'dev']
886
+ for file in files:
887
+ o = open('../data/qnli/{}_ms_nw.jsonl'.format(file), 'w+')
888
+ fs = open('../data/qa/msmarco/statements_{}_ms.txt'.format(file))
889
+ statements = [s.rstrip('\n') for s in fs]
890
+ with open('../data/qa/msmarco/{}.jsonl'.format(file), 'r') as f:
891
+ n = 0
892
+ for line in f:
893
+ data = json.loads(line.rstrip('\n'))
894
+ if data['wellFormedAnswers'] != '[]' or data['query_type'] == 'DESCRIPTION':
895
+ continue
896
+ q = data['query']
897
+ a = data['answers'][0]
898
+
899
+ if a == 'No Answer Present.' or '_' in q:
900
+ continue
901
+
902
+ evidence = ''
903
+ for p in data['passages']:
904
+ if p['is_selected'] == 1:
905
+ evidence = p['passage_text']
906
+
907
+ if '\n' in q or '\n' in a:
908
+ nl = (q + a).count('\n') + 1
909
+ print(nl)
910
+ n += nl
911
+ continue
912
+
913
+ o.write(json.dumps({
914
+ 'statement': statements[n],
915
+ 'question': q,
916
+ 'answer': a,
917
+ 'evidence': evidence
918
+ }) + '\n')
919
+
920
+ n += 1
921
+
922
+
923
+ def create_msmarco_sg():
924
+ files = ['train', 'dev']
925
+ for file in files:
926
+ o = open('../data/qa/msmarco/{}_ms_sg.txt'.format(file), 'w+')
927
+ with open('../data/qa/msmarco/{}.jsonl'.format(file), 'r') as f:
928
+ for line in f:
929
+ data = json.loads(line.rstrip('\n'))
930
+ if data['wellFormedAnswers'] != '[]' or data['query_type'] == 'DESCRIPTION':
931
+ continue
932
+ q = data['query']
933
+ a = data['answers'][0]
934
+
935
+ if a == 'No Answer Present.' or '_' in q:
936
+ continue
937
+
938
+ if q.endswith('.'):
939
+ q = q[:len(q) - 1] + '?'
940
+ if not q.endswith('?'):
941
+ q += ' ?'
942
+
943
+ o.write(q + ' ' + a + '\n')
944
+ """
945
+ o.write(json.dumps({
946
+ 'statement': fa,
947
+ 'question': q,
948
+ 'answer': a,
949
+ 'evidence': evidence
950
+ }) + '\n')
951
+ """
952
+
953
+
954
+ def create_qar():
955
+ files = ['train', 'dev']
956
+ for file in files:
957
+ o = open('../data/qar/{}2.jsonl'.format(file), 'w+')
958
+
959
+ f = open('../data/qa/nq/{}_all.jsonl'.format(file))
960
+ for line in f:
961
+ d = json.loads(line.rstrip('\n'))
962
+ if d['evidence'] == '':
963
+ continue
964
+ o.write(json.dumps({
965
+ 'question': d['question'],
966
+ 'answer': d['evidence'],
967
+ 'source': 'natural-questions'
968
+ }) + '\n')
969
+
970
+ f = open('../data/qa/msmarco/{}_all.jsonl'.format(file))
971
+ for line in f:
972
+ d = json.loads(line.rstrip('\n'))
973
+ if d['evidence'] == '':
974
+ continue
975
+ o.write(json.dumps({
976
+ 'question': d['question'],
977
+ 'answer': d['evidence'],
978
+ 'source': 'msmarco'
979
+ }) + '\n')
980
+
981
+
982
+ """
983
+ question_inputs = get_inputs_text(texts=question_texts, tokenizer=tokenizer,
984
+ max_length=max_question_length)
985
+ answer_inputs = get_inputs_text(texts=answer_texts, tokenizer=tokenizer,
986
+ max_length=max_answer_length)
987
+
988
+ x = {
989
+ 'question_input_word_ids': question_inputs[0],
990
+ 'question_input_masks': question_inputs[1],
991
+ 'question_input_segments': question_inputs[2],
992
+ 'answer_input_word_ids': answer_inputs[0],
993
+ 'answer_input_masks': answer_inputs[1],
994
+ 'answer_input_segments': answer_inputs[2]
995
+ }
996
+ y = np.zeros((self.batch_size, self.batch_size))
997
+ """
998
+
999
+ """
1000
+ create_random_wiki_paragraphs('../data/wikipedia/wiki_paragraphs.jsonl',
1001
+ '../data/random_wiki_paragraphs.jsonl',
1002
+ total_samples=1000000,
1003
+ max_length=510)
1004
+ """
1005
+
1006
+ # preprocess_nq_dataset()
models/dense_retriever.py ADDED
@@ -0,0 +1,47 @@
1
+ import h5py
2
+ import logging
3
+ import pickle
4
+ import sqlite3
5
+ import struct
6
+
7
+ import numpy as np
8
+
9
+ from models.vector_index import VectorIndex
10
+
11
+ import random
12
+
13
+
14
+ class DenseRetriever:
15
+ def __init__(self, model, db_path, batch_size=64, use_gpu=False, debug=False):
16
+ self.model = model
17
+ self.vector_index = VectorIndex(768)
18
+ self.batch_size = batch_size
19
+ self.use_gpu = use_gpu
20
+ self.db = sqlite3.connect(db_path)
21
+ self.db.row_factory = sqlite3.Row
22
+ self.debug = debug
23
+
24
+ def load_pretrained_index(self, path):
25
+ self.vector_index.load(path)
26
+
27
+ def populate_index(self, table_name):
28
+ cur = self.db.cursor()
29
+ query = f'SELECT * FROM {table_name} ORDER BY idx' if not self.debug \
30
+ else f'SELECT * FROM {table_name} ORDER BY idx LIMIT 1000'
31
+ for r in cur.execute(query):
32
+ e = r['encoded']
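+ # 'encoded' stores the 768-dim embedding as packed float32 bytes; unpack 4 bytes per dimension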
33
+ v = [np.float32(struct.unpack('f', e[i*4:(i+1)*4])[0]) for i in range(int(len(e)/4))]
34
+ self.vector_index.index.add(np.ascontiguousarray([v]))
35
+ print(f"\rAdded {self.vector_index.index.ntotal} vectors", end='')
36
+ print()
37
+ logging.info("Finished adding vectors.")
38
+
39
+ def search(self, queries, limit=1000, probes=512, min_similarity=0):
40
+ query_vectors = self.model.encode(queries, batch_size=self.batch_size)
41
+ ids, similarities = self.vector_index.search(query_vectors, k=limit, probes=probes)
42
+ results = []
43
+ for j in range(len(ids)):
44
+ results.append([
45
+ (ids[j][i], similarities[j][i]) for i in range(len(ids[j])) if similarities[j][i] > min_similarity
46
+ ])
47
+ return results
models/encoder/.DS_Store ADDED
Binary file (6.15 kB). View file
 
models/encoder/0_BERT/config.json ADDED
@@ -0,0 +1,20 @@
1
+ {
2
+ "architectures": [
3
+ "BertModel"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "hidden_act": "gelu",
7
+ "hidden_dropout_prob": 0.1,
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "intermediate_size": 3072,
11
+ "layer_norm_eps": 1e-12,
12
+ "max_position_embeddings": 512,
13
+ "model_type": "bert",
14
+ "num_attention_heads": 12,
15
+ "num_hidden_layers": 12,
16
+ "output_past": true,
17
+ "pad_token_id": 0,
18
+ "type_vocab_size": 2,
19
+ "vocab_size": 28996
20
+ }
models/encoder/0_BERT/sentence_bert_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "max_seq_length": 256,
3
+ "do_lower_case": false
4
+ }
models/encoder/0_BERT/special_tokens_map.json ADDED
@@ -0,0 +1 @@
1
+ {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
models/encoder/0_BERT/tokenizer_config.json ADDED
@@ -0,0 +1 @@
1
+ {"do_lower_case": false, "max_len": 512, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
models/encoder/0_BERT/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
models/encoder/1_Pooling/config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "word_embedding_dimension": 768,
3
+ "pooling_mode_cls_token": false,
4
+ "pooling_mode_mean_tokens": true,
5
+ "pooling_mode_max_tokens": false,
6
+ "pooling_mode_mean_sqrt_len_tokens": false
7
+ }
models/encoder/config.json ADDED
@@ -0,0 +1,3 @@
1
+ {
2
+ "__version__": "1.0"
3
+ }
models/encoder/modules.json ADDED
@@ -0,0 +1,14 @@
1
+ [
2
+ {
3
+ "idx": 0,
4
+ "name": "0",
5
+ "path": "0_BERT",
6
+ "type": "sentence_transformers.models.BERT"
7
+ },
8
+ {
9
+ "idx": 1,
10
+ "name": "1",
11
+ "path": "1_Pooling",
12
+ "type": "sentence_transformers.models.Pooling"
13
+ }
14
+ ]
models/encoder/stats.csv ADDED
@@ -0,0 +1,220 @@
1
+ 0.5779840120767146,0.8015233128623872,0.8583559199917659,0.8996466188630048,0.9569938587161629,0.6784459557674578
2
+ 0.5916046248327443,0.817116684392905,0.8708958040278588,0.9097505746732082,0.9613682368682883,0.6921505611721683
3
+ 0.6190860122825677,0.8418190551343191,0.8930078567262497,0.927230932857584,0.970237074141421,0.7180167815120542
4
+ 0.6270628195011494,0.8501046419871685,0.8993378392287371,0.9313308402236937,0.9730847085463341,0.7257908699942954
5
+ 0.6326380073420935,0.8536212989329948,0.9017051497581227,0.9325831131848904,0.9728273921844444,0.7306655399664306
6
+ 0.6432565958760764,0.8622328198442378,0.9086698459532714,0.9393248018664013,0.9762754314337667,0.740502524041259
7
+ 0.6463272377946272,0.8637252547431983,0.9105739870312554,0.94061138367585,0.9756750265893573,0.7430413705788191
8
+ 0.6501698287988472,0.8689573541016228,0.9141592616735856,0.9429100765087316,0.9773390057295777,0.7468761817476934
9
+ 0.6538580299859333,0.8707757230589769,0.9151542182728926,0.9432531649912512,0.9766699831886644,0.7501280666772407
10
+ 0.6536007136240436,0.8742066078841734,0.9184993309774591,0.9457920197618966,0.9773218513054517,0.751086942266862
11
+ 0.6515593371530518,0.8723539300785672,0.9152399903935224,0.9439908052286685,0.9773561601537036,0.7486908765241613
12
+ 0.6565512745737125,0.871650598689402,0.9158060863896799,0.9442824304388102,0.9770645349435619,0.7520715039306786
13
+ 0.6568600542079802,0.8752358733317323,0.919545750849144,0.9464781967269359,0.9773733145778296,0.7533833000405535
14
+ 0.6647682437300579,0.8794387072425979,0.922376230829931,0.9491371324664631,0.9790716025663019,0.7595162032832385
15
+ 0.6617490650838851,0.8790270010635743,0.919528596425018,0.9472158369643531,0.9780938003911208,0.7577154555045574
16
+ 0.6612172779359797,0.8765739184135588,0.9201290012694274,0.945397468006999,0.9761381960407589,0.756179860041224
17
+ 0.664836861426562,0.8797989501492435,0.9218787525302775,0.9485367276220538,0.9781452636634989,0.7597794971475365
18
+ 0.6645795450646722,0.8795073249391018,0.9216557450166398,0.9480220948982743,0.9778707928774831,0.7592882243172064
19
+ 0.6696229457577109,0.8815658558342196,0.9232682608844821,0.9485367276220538,0.9783511167530106,0.7631352909333269
20
+ 0.6653686485744673,0.8806052080831647,0.9231824887638522,0.9483480289566679,0.9792259923834357,0.7605593087121496
21
+ 0.6703262771468762,0.8829210553401722,0.924383298452671,0.9496689196143685,0.9781967269358768,0.763841387028115
22
+ 0.6659004357223728,0.8810169142621882,0.9229594812502144,0.9479020139293924,0.9770816893676879,0.7606746681736507
23
+ 0.6761587813497101,0.8852025937489278,0.9259100421998834,0.9505437952447936,0.9787456685079082,0.7682947389867827
24
+ 0.6741688681510961,0.8837273132740934,0.9250866298418362,0.9504580231241637,0.9782996534806326,0.7671206271858944
25
+ 0.6687137612790338,0.882183415102755,0.9244862249974268,0.9500120080968882,0.9785569698425224,0.7634663607979358
26
+ 0.6759014649878203,0.8859745428345971,0.9264418293477888,0.9514529797234706,0.9791745291110577,0.7687004551272422
27
+ 0.6738600885168284,0.8873640511888016,0.9274367859470958,0.9514015164510927,0.9790887569904279,0.7675660379918043
28
+ 0.6721446461042303,0.8834185336398257,0.9238858201530175,0.9491885957388411,0.9771503070641918,0.7653706174723683
29
+ 0.673396919065427,0.8885991697258723,0.9287233677565444,0.9533399663773288,0.9803581843757505,0.7681142393185553
30
+ 0.677205201221395,0.8848766596905342,0.9255669537173637,0.9494802209489828,0.9775963220914674,0.7691179334267018
31
+ 0.6725906611315058,0.8832984526709439,0.9240916732425293,0.9494116032524789,0.9772189247606958,0.765669866288031
32
+ 0.679126496723505,0.8885648608776203,0.9288091398771743,0.9522249288091399,0.978900058325042,0.7717041464381332
33
+ 0.6765018698322297,0.8865749476790065,0.9267849178303085,0.9513500531787148,0.9794146910488215,0.7692736782586759
34
+ 0.6792122688441349,0.8870724259786599,0.9285689779394106,0.9536144371633444,0.9801351768621127,0.771438903068368
35
+ 0.6806360860465914,0.8879816104573369,0.9286204412117885,0.9532370398325728,0.9795690808659553,0.772461625855239
36
+ 0.677170892373143,0.8858716162898411,0.926287439530655,0.9515044429958486,0.9794661543211994,0.7698102128989501
37
+ 0.685816722132638,0.8909493258311318,0.9305588911380245,0.9536830548598484,0.9806498095858922,0.776847345661558
38
+ 0.6874292380004803,0.8926304593954781,0.9318969362198511,0.9543177685525097,0.9809757436442859,0.778282769703823
39
+ 0.6828490067588431,0.8883418533639825,0.9285003602429066,0.9527567159570454,0.9795176175935774,0.774156750244517
40
+ 0.6826946169417093,0.890726318317494,0.9309705973170481,0.954986791093423,0.980889971523656,0.7746702901532684
41
+ 0.679143651147631,0.8890451847531479,0.9278313377019933,0.9516759872371084,0.9774419322743335,0.771751324626232
42
+ 0.6816310426458984,0.8901602223213366,0.9304731190173946,0.9536144371633444,0.9792088379593097,0.7738497806623483
43
+ 0.682986242151851,0.8900229869283288,0.929066456239064,0.95205338456788,0.9785741242666484,0.7741901341300708
44
+ 0.6789549524822451,0.8901945311695887,0.9288949119978043,0.9527052526846673,0.9786084331149003,0.7717603788208348
45
+ 0.687738017634748,0.8921844443682025,0.9321027893093629,0.9541290698871239,0.9797749339554671,0.778173339236671
46
+ 0.6816653514941503,0.8897313617181871,0.9295124712663396,0.9536487460115964,0.9787456685079082,0.7736270773705498
47
+ 0.6899852471952517,0.8917212749168011,0.930095721686623,0.9531512677119429,0.9784197344495146,0.7792509607389516
48
+ 0.68392973547878,0.8919271280063128,0.9309191340446701,0.9545922393385254,0.9804096476481284,0.7758120806040231
49
+ 0.687772326483,0.8929735478779978,0.9319483994922291,0.9544893127937695,0.980907125947782,0.7784736944841473
50
+ 0.6842042062647957,0.8922359076405805,0.9316224654338354,0.9548838645486671,0.980804199403026,0.7761278134585238
51
+ 0.6846159124438193,0.8934538717535252,0.9321199437334888,0.9548495557004152,0.9802381034068687,0.7768488500009861
52
+ 0.6882354959344015,0.8937626513877929,0.9338010772978351,0.9569423954437849,0.9811815967337977,0.7791075653748736
53
+ 0.6855250969224963,0.8929049301814939,0.9323086423988747,0.9555871959378324,0.9802895666792466,0.77714622677133
54
+ 0.6878752530277559,0.8921501355199506,0.9322057158541188,0.955021099941675,0.9800494047414828,0.778441356120615
55
+ 0.684873228805709,0.8899543692318249,0.9298212509006073,0.9535458194668405,0.9789686760215459,0.7759056803242941
56
+ 0.684890383229835,0.8920815178234467,0.931982708340481,0.9547980924280371,0.9799979414691049,0.7768483470243297
57
+ 0.6939479191683535,0.8962157340378083,0.9336466874807012,0.9556901224825882,0.9813531409750574,0.7834821941604841
58
+ 0.6864514358252993,0.8933680996328953,0.9311936048306858,0.9543863862490136,0.9796376985624593,0.7777730592023087
59
+ 0.6901567914365114,0.8946889902905959,0.9330634370604178,0.9547980924280371,0.9795347720177033,0.7806016928782297
60
+ 0.6914776820942121,0.8939513500531787,0.9313479946478197,0.9534600473462106,0.9786598963872782,0.7812826686600813
61
+ 0.6920266236662436,0.894637527018218,0.9334408343911895,0.9557758946032182,0.9812502144303016,0.781829703594363
62
+ 0.6930215802655505,0.8951521597419975,0.9326517308813943,0.9553298795759426,0.9809242803719079,0.7825955960750406
63
+ 0.6919923148179915,0.8950149243489895,0.9326517308813943,0.9554499605448246,0.9804096476481284,0.781802873512625
64
+ 0.6883727313274094,0.8899543692318249,0.9277970288537414,0.9525508628675335,0.9770816893676879,0.7779526954120721
65
+ 0.6861769650392836,0.8913781864342813,0.9306275088345284,0.9539575256458641,0.9791059114145538,0.7770877463274903
66
+ 0.6918550794249837,0.8962328884619344,0.9347617250488901,0.9580059697395958,0.9829313479946479,0.7821665815849399
67
+ 0.6930901979620544,0.8964044327031941,0.9333550622705595,0.9554671149689505,0.9809414347960339,0.7827637999409208
68
+ 0.6845816035955673,0.8918928191580608,0.9293237726009538,0.9518303770542423,0.9778364840292312,0.775727886193286
69
+ 0.6941023089854873,0.8980684118434145,0.9342470923251106,0.9558788211479741,0.981833464850585,0.7839931362022297
70
+ 0.689573541016228,0.8941572031426905,0.9325831131848904,0.9552269530311868,0.9811815967337977,0.7801137155274634
71
+ 0.6899680927711257,0.891807047037431,0.9305760455621505,0.9524136274745256,0.9781452636634989,0.7794242917082135
72
+ 0.6926441829347789,0.895598174769273,0.9327889662744022,0.9556215047860843,0.9811472878855457,0.7823044597371529
73
+ 0.6946855594057707,0.8966789034892099,0.9346244896558823,0.9555871959378324,0.9814217586715615,0.7841933350778004
74
+ 0.690654269736165,0.8944659827769582,0.9323257968230007,0.9552441074553127,0.9805297286170104,0.7806231726203139
75
+ 0.6921810134833773,0.895649638041651,0.9332006724534258,0.955913129996226,0.9812159055820496,0.7820928716391452
76
+ 0.6918207705767317,0.895529557072769,0.9342985555974886,0.9558445122997221,0.9821250900607267,0.7823101370916966
77
+ 0.6965382372113769,0.899440765773493,0.9366658661268741,0.9581775139808557,0.9829313479946479,0.7861867148449412
78
+ 0.6914433732459602,0.8950320787731156,0.9328232751226542,0.9559817476927299,0.9814389130956874,0.7817211414999055
79
+ 0.6889731361718188,0.8940542765979346,0.931999862764607,0.9545922393385254,0.9805983463135143,0.7797127544519308
80
+ 0.6913061378529523,0.8949634610766116,0.9325144954883865,0.9561189830857378,0.9811815967337977,0.781525423126717
81
+ 0.6936048306858339,0.896473050399698,0.9326002676090164,0.9560846742374859,0.9814903763680654,0.7835110543483639
82
+ 0.6941023089854873,0.8942258208391944,0.9313308402236937,0.9537859814046042,0.978917212749168,0.7828756585433598
83
+ 0.6987168490753766,0.8988403609290836,0.9345901808076302,0.9566336158095172,0.9807527361306481,0.787165665870235
84
+ 0.69545750849144,0.8965416680962021,0.9326860397296463,0.9552269530311868,0.9796376985624593,0.7842072813024578
85
+ 0.6983737605928569,0.8981370295399184,0.9356537551034412,0.9574055648951865,0.9815075307921913,0.7869324771044554
86
+ 0.696383847394243,0.8968676021545957,0.9340583936597249,0.9562562184787456,0.9800150958932309,0.7851641911299555
87
+ 0.6974988849624318,0.8990119051703435,0.9364771674614883,0.9584348303427453,0.9822794798778605,0.7869184646060629
88
+ 0.6960750677599753,0.8977939410573987,0.9344872542628744,0.9556729680584622,0.979912169348475,0.7852065186352417
89
+ 0.6984938415617388,0.8997838542560126,0.936785947095756,0.9581775139808557,0.9815933029128212,0.787530466944267
90
+ 0.6984766871376128,0.8989261330497135,0.935876762617079,0.9571997118056746,0.9802037945586166,0.7871252651258276
91
+ 0.6960922221841013,0.8987888976567057,0.9356022918310632,0.956959549867911,0.9812330600061756,0.7855466480210304
92
+ 0.6972415686005421,0.8972278450612413,0.9343157100216146,0.9560675198133599,0.9810786701890417,0.7856690263774906
93
+ 0.6998661954918174,0.8990805228668474,0.9345901808076302,0.9566679246577693,0.9809242803719079,0.7876728318822217
94
+ 0.6978248190208255,0.8990290595944694,0.9359968435859608,0.9581260507084777,0.9816276117610732,0.7868541067616021
95
+ 0.6961265310323532,0.8981198751157924,0.933852540570213,0.9557415857549663,0.9791402202628058,0.7853667716098274
96
+ 0.7032113081963839,0.9011047449137133,0.9365972484303702,0.9581088962843517,0.9811472878855457,0.7905386234241963
97
+ 0.7025422856554705,0.8997323909836347,0.9361169245548427,0.9574227193193124,0.9810958246131677,0.7898037355362394
98
+ 0.6948399492229046,0.8977253233608947,0.9346587985041342,0.9565993069612653,0.9812845232785535,0.7844807566485703
99
+ 0.6989227021648883,0.8963358150066902,0.9337839228737091,0.9558959755721,0.9803067211033726,0.7864312940359843
100
+ 0.6998661954918174,0.8990462140185954,0.934796033897142,0.9556215047860843,0.979912169348475,0.7876511656582856
101
+ 0.6971729509040382,0.8959412632517927,0.9324115689436305,0.954008988918242,0.9792088379593097,0.7851343541574478
102
+ 0.6985967681064946,0.9002984869797921,0.9374892784849213,0.9583319037979895,0.9815933029128212,0.7875866297739705
103
+ 0.7022506604453288,0.8989947507462175,0.9346073352317563,0.9563419905993755,0.9803410299516245,0.7894850593461494
104
+ 0.699231481799156,0.8983600370535562,0.9344700998387484,0.9559817476927299,0.979946478196727,0.7872299076283827
105
+ 0.7014787113596597,0.9008817374000755,0.936820255944008,0.9582632861014856,0.981713383881703,0.7897078452703377
106
+ 0.6975846570830617,0.8972621539094933,0.933989775963221,0.9556558136343363,0.9795519264418293,0.7858221981230687
107
+ 0.6996260335540536,0.8993721480769891,0.9361855422513466,0.9576457268329502,0.980855662675404,0.7881251554498462
108
+ 0.6976361203554396,0.9005729577658078,0.936785947095756,0.9582804405256116,0.9816276117610732,0.7872117895320041
109
+ 0.7015816379044155,0.8990119051703435,0.9352077400761657,0.956856623323155,0.9794146910488215,0.7890413758536403
110
+ 0.7018217998421793,0.9008302741276976,0.9366315572786221,0.9581432051326036,0.9818677736988369,0.7899203650821491
111
+ 0.705956016056541,0.9030775036882012,0.9381583010258345,0.9595327134868082,0.981816310426459,0.7932645670386803
112
+ 0.7033828524376436,0.9035063642913508,0.9389817133838817,0.9598414931210759,0.9826568772086322,0.7920589796937257
113
+ 0.7020276529316911,0.9032147390812091,0.9387415514461179,0.9598586475452019,0.9823137887261124,0.7907239281404607
114
+ 0.7005523724568566,0.9007445020070676,0.9359968435859608,0.9564620715682575,0.9800665591656088,0.7889600463763118
115
+ 0.7032284626205099,0.9015336055168628,0.9372491165471575,0.957834425498336,0.9813188321268055,0.7909847811260302
116
+ 0.7028853741379902,0.9019624661200123,0.9383813085394723,0.9585549113116273,0.9822108621813566,0.7910605299160958
117
+ 0.7070195903523518,0.9029574227193193,0.9383641541153463,0.957851579922462,0.9813359865509315,0.7936546808268249
118
+ 0.7017531821456754,0.901276289154973,0.936820255944008,0.9575771091364463,0.9807870449789,0.7899696413107273
119
+ 0.7035029334065256,0.9027687240539335,0.9378323669674409,0.9588979997941469,0.9815761484886952,0.7913691735211271
120
+ 0.70689950938347,0.9027858784780595,0.9388101691426218,0.9591381617319107,0.981833464850585,0.7936974704832799
121
+ 0.703039763955124,0.903300511201839,0.9389817133838817,0.9593783236696745,0.9823137887261124,0.7913347277016449
122
+ 0.7016502556009194,0.9015336055168628,0.9378666758156928,0.9586235290081312,0.9816790750334511,0.7902131444981638
123
+ 0.7057158541187772,0.9025971798126737,0.9373177342436615,0.9591553161560367,0.9822794798778605,0.7929256361254804
124
+ 0.7009640786358802,0.9010875904895873,0.9357223727999451,0.9566336158095172,0.9807184272823961,0.7890500472662872
125
+ 0.7051669125467458,0.9037808350773665,0.9386043160531101,0.9593268603972964,0.9816619206093251,0.7929897654311266
126
+ 0.7071225168971078,0.9051017257350671,0.9408343911894878,0.9611280749305245,0.983703297080317,0.794594734873485
127
+ 0.708769341613202,0.9037636806532404,0.9391361032010156,0.9600130373623358,0.9819535458194668,0.794976115563586
128
+ 0.7083404810100525,0.9041582324081381,0.9384670806601022,0.9596184856074381,0.9822108621813566,0.7949226710372033
129
+ 0.7017703365698014,0.9011562081860912,0.9369574913370158,0.9579545064672179,0.9816790750334511,0.7897549305809078
130
+ 0.7036573232236594,0.9018766939993824,0.9378495213915669,0.9589323086423989,0.981850619274711,0.7916034369572444
131
+ 0.7054585377568875,0.9039523793186263,0.9384156173877243,0.9596184856074381,0.9826568772086322,0.7932860445089596
132
+ 0.706916663807596,0.903163275808831,0.938707242597866,0.9589837719147768,0.9820050090918447,0.7939775358361482
133
+ 0.7069852815040999,0.9050845713109411,0.9400624421038186,0.9609908395375167,0.9829485024187739,0.7945385075253434
134
+ 0.7062647956908087,0.903163275808831,0.9378152125433149,0.9587950732493911,0.9812673688544276,0.7934864404440743
135
+ 0.7067108107180842,0.9052904244004529,0.9403883761622123,0.9617456341990599,0.983703297080317,0.7943919250071902
136
+ 0.7090781212474697,0.9049301814938072,0.9391017943527635,0.9599958829382098,0.9820907812124747,0.7957028931092571
137
+ 0.7073626788348715,0.9060452190619961,0.9407657734929838,0.9618828695920678,0.9837890692009469,0.7950442483313808
138
+ 0.7094212097299893,0.9053075788245789,0.9405084571310941,0.9614025457165403,0.9832915909012935,0.7962251490469788
139
+ 0.7094898274264932,0.9045699385871616,0.9400109788314406,0.960699214327375,0.98279411260164,0.7961054241729613
140
+ 0.7089065770062098,0.9050159536144372,0.9404055305863382,0.9606305966308711,0.9826911860568841,0.7957211011304709
141
+ 0.7072940611383676,0.9039180704703743,0.9389302501115038,0.9589494630665248,0.9811129790372937,0.7943070084981082
142
+ 0.7095584451229973,0.905153189007445,0.9410745531272515,0.9616255532301781,0.9830857378117817,0.7966290988294025
143
+ 0.7099358424537688,0.9054448142175867,0.9414176416097711,0.9606820599032491,0.9828970391463958,0.7963943395382952
144
+ 0.7119772189247607,0.907040175661303,0.9419494287576766,0.9619686417126977,0.9832572820530415,0.7984175961150379
145
+ 0.7118056746835009,0.9052561155522009,0.9400109788314406,0.9593783236696745,0.9812845232785535,0.797624930996624
146
+ 0.7138127423062408,0.9064226163927677,0.9405084571310941,0.9602875081483515,0.9821593989089786,0.798959296397886
147
+ 0.7152194050845713,0.9079322057158541,0.9420695097265585,0.9617284797749339,0.9833258997495454,0.8005330807512997
148
+ 0.7125433149209182,0.9075548083850825,0.9420008920300545,0.9619686417126977,0.9836518338079391,0.7988329781501703
149
+ 0.7090438123992178,0.9046728651319175,0.9386043160531101,0.9594469413661784,0.9806326551617662,0.7953849481751337
150
+ 0.7113768140803514,0.9046042474354136,0.9385700072048582,0.9591038528836587,0.9800665591656088,0.7968066498706435
151
+ 0.7120801454695166,0.9053590420969568,0.9400795965279446,0.9601331183312176,0.9816962294575771,0.7977260981425256
152
+ 0.7100559234226507,0.9055648951864685,0.9402682951933303,0.9606477510549971,0.981833464850585,0.7962313583525185
153
+ 0.7126977047380519,0.9069887123889251,0.9412117885202593,0.9620201049850756,0.9824681785432463,0.7985018939595462
154
+ 0.7109307990530758,0.9065770062099016,0.940645692524102,0.9599958829382098,0.9814903763680654,0.7971671390474596
155
+ 0.7122516897107765,0.9061138367585,0.94069715579648,0.9602017360277215,0.9814217586715615,0.7978698135709755
156
+ 0.7104676296016743,0.9054962774899646,0.940645692524102,0.9609565306892648,0.9817476927299551,0.7965702130914314
157
+ 0.7113253508079733,0.9059251380931143,0.9408687000377397,0.9609222218410128,0.9826054139362541,0.7971691791256932
158
+ 0.7113253508079733,0.9059594469413662,0.9399938244073146,0.9603904346931074,0.9814903763680654,0.7971970327904135
159
+ 0.7127663224345558,0.9066799327546574,0.9418293477887947,0.9615054722612962,0.9822108621813566,0.7987212070933885
160
+ 0.7143616838782723,0.9088928534669091,0.9430816207499915,0.962654818677737,0.9831543555082856,0.8000986555854859
161
+ 0.7152365595086972,0.9098878100662161,0.9432874738395032,0.962723436374241,0.9833430541736714,0.800961004414098
162
+ 0.7179641129447284,0.9099049644903421,0.9438021065632827,0.9623460390434693,0.9831886643565375,0.802537516552073
163
+ 0.7143102206058942,0.9087384636497753,0.9414347960338971,0.9615740899578001,0.9825882595121281,0.7999251770665431
164
+ 0.7152708683569493,0.9076577349298384,0.9418121933646687,0.9613682368682883,0.9823995608467424,0.8003246145804358
165
+ 0.7156997289600988,0.9076748893539643,0.9420523553024325,0.9611452293546505,0.9821765533331046,0.8005897851903018
166
+ 0.7151507873880674,0.9085497649843894,0.9421724362713144,0.9616255532301781,0.982776958177514,0.800298538385517
167
+ 0.7177582598552167,0.9090815521322949,0.9427385322674717,0.9619000240161938,0.9830514289635297,0.8021783842434324
168
+ 0.7155453391429649,0.9086183826808935,0.9423954437849521,0.9620372594092016,0.9829485024187739,0.8008689468904545
169
+ 0.7144817648471541,0.9089614711634131,0.9423268260884482,0.961642707654304,0.9829828112670258,0.800437947894952
170
+ 0.7129893299481936,0.9073146464473187,0.9409544721583697,0.9608192952962569,0.9825539506638762,0.7985213141537784
171
+ 0.7184444368202559,0.9101965897004838,0.9442652760146842,0.9625175832847291,0.9839777678663327,0.8031311330802511
172
+ 0.7183586646996261,0.9102823618211137,0.9441280406216763,0.9624661200123512,0.9839777678663327,0.8031333976315808
173
+ 0.71842728239613,0.9080351322606101,0.9417778845164168,0.9610594572340206,0.9825882595121281,0.8024804003813029
174
+ 0.7160085085943665,0.9082924486224997,0.9421895906954404,0.9607849864480049,0.9826397227845062,0.8007421064452407
175
+ 0.7192163859059252,0.9107283768483893,0.9447799087384636,0.9631008337050125,0.9839263045939548,0.8037904528738081
176
+ 0.7172950904038151,0.9081209043812399,0.942584142450338,0.9615397811095482,0.9822966343019864,0.801610435193648
177
+ 0.7180327306412324,0.9065255429375236,0.9410745531272515,0.9600130373623358,0.9811129790372937,0.8011702379380538
178
+ 0.7188561429992795,0.9108313033931451,0.9441966583181802,0.9632209146738944,0.9834459807184273,0.8034972168782052
179
+ 0.7175009434933269,0.9095790304319484,0.9428071499639757,0.9612653103235325,0.9821079356366007,0.8020695604970132
180
+ 0.7175009434933269,0.9095104127354444,0.9429272309328576,0.9617627886231859,0.9829313479946479,0.8021815489680041
181
+ 0.7151336329639414,0.9078121247469723,0.9413490239132672,0.9610423028098947,0.9826225683603801,0.800175479488273
182
+ 0.7194565478436888,0.9100593543074759,0.942566988026212,0.9621230315298316,0.9825196418156242,0.8033165975051533
183
+ 0.7187189076062717,0.9100593543074759,0.9426527601468419,0.9621744948022095,0.9823995608467424,0.8029329362048628
184
+ 0.7217552406765705,0.9117747967200741,0.944436820255944,0.9629807527361306,0.9831715099324115,0.8051882216053523
185
+ 0.7207431296531376,0.9112087007239167,0.9439908052286685,0.962654818677737,0.982742649329262,0.8044874117887816
186
+ 0.7204171955947439,0.9098706556420901,0.9441451950458023,0.9623288846193433,0.9828970391463958,0.8041066613024709
187
+ 0.7210004460150273,0.9107969945448932,0.9447799087384636,0.962689127525989,0.9828798847222698,0.8045389583450916
188
+ 0.7219096304937044,0.9123923559886095,0.9450886883727313,0.963632620852918,0.9833258997495454,0.8057297371947867
189
+ 0.7207431296531376,0.9095447215836965,0.9431845472947473,0.9618828695920678,0.9823309431502385,0.8041280599542056
190
+ 0.7218753216454523,0.9117919511442001,0.9450372251003534,0.9630665248567606,0.9829656568428998,0.8054852110655439
191
+ 0.7245342573849796,0.912855525440011,0.9447455998902117,0.96358115758054,0.9836346793838131,0.8073694350037202
192
+ 0.7228188149723814,0.9123752015644835,0.9446083644972039,0.9633753044910283,0.9835660616873092,0.8061516152582747
193
+ 0.7228016605482553,0.9123923559886095,0.945346004734621,0.9642501801214534,0.9842007753799705,0.8064601603251264
194
+ 0.7242426321748379,0.9127182900470031,0.9457577109136446,0.9643702610903352,0.9844924005901122,0.8075001427041348
195
+ 0.7249459635640032,0.9134730847085464,0.9458606374584005,0.9644217243627131,0.9844237828936082,0.8081538327780139
196
+ 0.7239338525405702,0.9126839811987512,0.9450200706762274,0.9633066867945244,0.9838405324733249,0.8070410918248734
197
+ 0.7236250729063025,0.9124266648368614,0.9448656808590936,0.9631866058256424,0.9836003705355612,0.806859810712907
198
+ 0.7224242632174838,0.911826259992452,0.9446941366178337,0.962723436374241,0.9832744364771675,0.8058241791370292
199
+ 0.722218410127972,0.9117404878718222,0.9443853569835661,0.962603355405359,0.9832401276289156,0.8056427934609479
200
+ 0.7226301163069956,0.9120149586578379,0.9448485264349675,0.9632552235221463,0.9833945174460493,0.8060240064701059
201
+ 0.7220297114625862,0.9115860980546883,0.9442138127423062,0.962689127525989,0.9828970391463958,0.805447713601981
202
+ 0.7237108450269324,0.9122722750197276,0.9448828352832196,0.9632723779462723,0.9833430541736714,0.8067396702769626
203
+ 0.7235736096339246,0.9124609736851134,0.9450715339486053,0.9634096133392802,0.9836346793838131,0.8067551081005684
204
+ 0.7245857206573575,0.913781864342814,0.9461007993961643,0.9646618863004769,0.9845781727107421,0.8077739135063374
205
+ 0.7238309259958143,0.9125124369574913,0.9448656808590936,0.9632895323703983,0.9833087453254195,0.8068440166640881
206
+ 0.7226815795793735,0.910934229937901,0.9437849521391567,0.962723436374241,0.9830171201152778,0.8056550566673708
207
+ 0.7228016605482553,0.9105911414553813,0.9439050331080385,0.9623803478917212,0.9831372010841596,0.8056326358407009
208
+ 0.7226129618828696,0.9121693484749717,0.9442652760146842,0.9631179881291385,0.9832915909012935,0.8061118839646912
209
+ 0.7240539335094521,0.9126153635022473,0.9447455998902117,0.9634096133392802,0.9834459807184273,0.8071669628005791
210
+ 0.7246200295056094,0.9130956873777747,0.9450886883727313,0.9632209146738944,0.9833773630219234,0.8075709724851876
211
+ 0.7247229560503654,0.9139190997358219,0.9458091741860226,0.9640786358801935,0.9842350842282225,0.8079037101451694
212
+ 0.7247058016262394,0.9128383710158849,0.9447970631625896,0.9633924589151542,0.9834802895666792,0.8075268394478943
213
+ 0.7238652348440663,0.9124609736851134,0.9447627543143376,0.9631351425532645,0.9833430541736714,0.8070234679633086
214
+ 0.7238652348440663,0.9124266648368614,0.9448828352832196,0.9633924589151542,0.9835489072631832,0.8069960834787101
215
+ 0.7244999485367276,0.9130785329536487,0.945397468006999,0.9638556283665557,0.9840635399869626,0.8075214826372252
216
+ 0.725117507805263,0.912889834288263,0.9453288503104951,0.9635983120046661,0.9838062236250729,0.8078535218173943
217
+ 0.7251518166535149,0.9133186948914125,0.9457405564895186,0.9640443270319415,0.9838919957457029,0.808052022558173
218
+ 0.725100353381137,0.9135931656774282,0.9460321816996603,0.9644045699385871,0.9844924005901122,0.8081660689818039
219
+ 0.7251346622293889,0.9133358493155385,0.9455347034000069,0.9641300991525714,0.9838405324733249,0.8080744792082037
220
+ 0.7249631179881292,0.9133701581637904,0.945397468006999,0.9638556283665557,0.9838233780491988,0.8079103037460894
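The rows above are a headerless, comma-separated log of six evaluation metrics that improve over successive steps. A minimal sketch for inspecting such a log (the file path and column names are placeholders, not taken from the repo):

```python
import pandas as pd

# Placeholder path and generic column names; the file itself has no header row.
df = pd.read_csv("path/to/metrics.csv", header=None,
                 names=[f"metric_{i}" for i in range(1, 7)])
print(df.tail())    # the most recent evaluation steps
print(df.idxmax())  # the step at which each metric peaks
```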
models/nli.py ADDED
@@ -0,0 +1,267 @@
1
+ import os
2
+ import random
3
+ import torch
4
+ import math
5
+ import argparse
6
+ import transformers
7
+ import logging
8
+
9
+ from torch import nn
10
+ from torch.nn import DataParallel
11
+ from models.data_utils import build_batch, LoggingHandler, get_examples
12
+ from tqdm import tqdm
13
+ from sklearn.metrics import precision_recall_fscore_support
14
+ from transformers import AutoConfig, RobertaModel, AutoModel, AutoTokenizer, AdamW
15
+
16
+ import torch.nn.functional as F
17
+ import numpy as np
18
+
19
+
20
+ class NLI(nn.Module):
21
+ """
22
+ NLI model based on RoBERTa-large (adapted from: https://github.com/yg211/bert_nli)
23
+ """
24
+
25
+ def __init__(self, model_path=None, device='cuda', parallel=False, debug=False, label_num=3, batch_size=16):
26
+ super(NLI, self).__init__()
27
+
28
+ lm = 'roberta-large'
29
+
30
+ if model_path is not None:
31
+ configuration = AutoConfig.from_pretrained(lm)
32
+ self.bert = RobertaModel(configuration)
33
+ else:
34
+ self.bert = AutoModel.from_pretrained(lm)
35
+
36
+ self.tokenizer = AutoTokenizer.from_pretrained(lm)
37
+ self.vdim = 1024
38
+ self.max_length = 256
39
+
40
+ self.nli_head = nn.Linear(self.vdim, label_num)
41
+ self.batch_size = batch_size
42
+
43
+ if parallel:
44
+ self.bert = DataParallel(self.bert)
45
+
46
+ # load trained model
47
+ if model_path is not None:
48
+ sdict = torch.load(model_path, map_location=lambda storage, loc: storage)
49
+ self.load_state_dict(sdict, strict=False)
50
+
51
+ self.to(device)
52
+ self.device = device
53
+
54
+ self.debug = debug
55
+
56
+ def load_model(self, sdict):
+ # NLI keeps its target device in self.device (set in __init__); there is
+ # no `gpu` attribute on this class, so load and move to that device directly.
+ self.load_state_dict(sdict)
+ self.to(self.device)
62
+
63
+ def forward(self, sent_pair_list):
64
+ all_probs = None
65
+ iterator = range(0, len(sent_pair_list), self.batch_size)
66
+ if self.debug:
67
+ iterator = tqdm(iterator, desc='batch')
68
+ for batch_idx in iterator:
69
+ probs = self.ff(sent_pair_list[batch_idx:batch_idx + self.batch_size]).data.cpu().numpy()
70
+ if all_probs is None:
71
+ all_probs = probs
72
+ else:
73
+ all_probs = np.append(all_probs, probs, axis=0)
74
+ labels = []
75
+ for pp in all_probs:
76
+ ll = np.argmax(pp)
77
+ if ll == 0:
78
+ labels.append('entailment')
79
+ elif ll == 1:
80
+ labels.append('contradiction')
81
+ else:
82
+ labels.append('neutral')
83
+ return labels, all_probs
84
+
85
+ def ff(self, sent_pair_list):
86
+ ids, types, masks = build_batch(self.tokenizer, sent_pair_list, max_len=self.max_length)
87
+ if ids is None:
88
+ return None
89
+ ids_tensor = torch.tensor(ids)
90
+ # types_tensor = torch.tensor(types)  # token type ids are not used by RoBERTa
91
+ masks_tensor = torch.tensor(masks)
92
+
93
+ ids_tensor = ids_tensor.to(self.device)
94
+ #types_tensor = types_tensor.to(self.device)
95
+ masks_tensor = masks_tensor.to(self.device)
96
+ # self.bert.to('cuda')
97
+ # self.nli_head.to('cuda')
98
+
99
+ cls_vecs = self.bert(input_ids=ids_tensor, attention_mask=masks_tensor)[1]
100
+ logits = self.nli_head(cls_vecs)
101
+ predict_probs = F.log_softmax(logits, dim=1)
102
+ return predict_probs
103
+
104
+ def save(self, output_path, config_dic=None, acc=None):
105
+ if acc is None:
106
+ model_name = 'nli_large_2.state_dict'
107
+ else:
108
+ model_name = 'nli_large_2_acc{}.state_dict'.format(acc)
109
+ opath = os.path.join(output_path, model_name)
110
+ if config_dic is None:
111
+ torch.save(self.state_dict(), opath)
112
+ else:
113
+ torch.save(config_dic, opath)
114
+
115
+
116
+ def get_scheduler(optimizer, scheduler: str, warmup_steps: int, t_total: int):
117
+ """
118
+ Returns the correct learning rate scheduler
119
+ """
120
+ scheduler = scheduler.lower()
121
+ if scheduler == 'constantlr':
122
+ return transformers.optimization.get_constant_schedule(optimizer)
123
+ elif scheduler == 'warmupconstant':
124
+ return transformers.optimization.get_constant_schedule_with_warmup(optimizer, warmup_steps)
125
+ elif scheduler == 'warmuplinear':
126
+ return transformers.optimization.get_linear_schedule_with_warmup(optimizer, warmup_steps, t_total)
127
+ elif scheduler == 'warmupcosine':
128
+ return transformers.optimization.get_cosine_schedule_with_warmup(optimizer, warmup_steps, t_total)
129
+ elif scheduler == 'warmupcosinewithhardrestarts':
130
+ return transformers.optimization.get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, warmup_steps,
131
+ t_total)
132
+ else:
133
+ raise ValueError("Unknown scheduler {}".format(scheduler))
134
+
135
+
136
+ def train(model, optimizer, scheduler, train_data, dev_data, batch_size, fp16, gpu,
137
+ max_grad_norm, best_acc, model_save_path):
138
+ loss_fn = nn.CrossEntropyLoss()
139
+ model.train()
140
+
141
+ step_cnt = 0
142
+ for pointer in tqdm(range(0, len(train_data), batch_size), desc='training'):
143
+ step_cnt += 1
144
+ sent_pairs = []
145
+ labels = []
146
+ for i in range(pointer, pointer + batch_size):
147
+ if i >= len(train_data):
148
+ break
149
+ sents = train_data[i].get_texts()
150
+ sent_pairs.append(sents)
151
+ labels.append(train_data[i].get_label())
152
+ predicted_probs = model.ff(sent_pairs)
153
+ if predicted_probs is None:
154
+ continue
155
+ true_labels = torch.LongTensor(labels)
156
+ if gpu:
157
+ true_labels = true_labels.to('cuda')
158
+ loss = loss_fn(predicted_probs, true_labels)
159
+ if fp16:
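+ # `amp` is NVIDIA Apex mixed precision; it is imported in the __main__ block below when fp16 is enabled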
160
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
161
+ scaled_loss.backward()
162
+ torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
163
+ else:
164
+ loss.backward()
165
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
166
+
167
+ optimizer.step()
168
+ scheduler.step()
169
+ optimizer.zero_grad()
170
+
171
+ if step_cnt % 5000 == 0:
172
+ acc = evaluate(model, dev_data, mute=True)
173
+ logging.info('==> step {} dev acc: {}'.format(step_cnt, acc))
174
+ model.train() # model was in eval mode in evaluate(); re-activate the train mode
175
+ if acc > best_acc:
176
+ best_acc = acc
177
+ logging.info('Saving model...')
178
+ model.save(model_save_path, model.state_dict())
179
+
180
+ return best_acc
181
+
182
+
183
+ def parse_args():
184
+ ap = argparse.ArgumentParser("arguments for bert-nli training")
185
+ ap.add_argument('-b', '--batch_size', type=int, default=64, help='batch size')
186
+ ap.add_argument('-ep', '--epoch_num', type=int, default=10, help='epoch num')
187
+ ap.add_argument('--fp16', type=int, default=0, help='use apex mixed precision training (1) or not (0)')
188
+ ap.add_argument('--gpu', type=int, default=1, help='use gpu (1) or not (0)')
189
+ ap.add_argument('-ss', '--scheduler_setting', type=str, default='WarmupLinear',
190
+ choices=['WarmupLinear', 'ConstantLR', 'WarmupConstant', 'WarmupCosine',
191
+ 'WarmupCosineWithHardRestarts'])
192
+ ap.add_argument('-mg', '--max_grad_norm', type=float, default=1., help='maximum gradient norm')
193
+ ap.add_argument('-wp', '--warmup_percent', type=float, default=0.1,
194
+ help='fraction of steps used for warmup')
195
+
196
+ args = ap.parse_args()
197
+ return args.batch_size, args.epoch_num, args.fp16, args.gpu, args.scheduler_setting, args.max_grad_norm, args.warmup_percent
198
+
199
+
200
+ def evaluate(model, test_data, mute=False):
201
+ model.eval()
202
+ sent_pairs = [test_data[i].get_texts() for i in range(len(test_data))]
203
+ all_labels = [test_data[i].get_label() for i in range(len(test_data))]
204
+ _, probs = model(sent_pairs)
205
+ all_predict = [np.argmax(pp) for pp in probs]
206
+ assert len(all_predict) == len(all_labels)
207
+
208
+ acc = len([i for i in range(len(all_labels)) if all_predict[i] == all_labels[i]]) * 1. / len(all_labels)
209
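+ # per-class precision/recall/F1 restricted to labels 0 (entailment) and 1 (contradiction)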
+ prf = precision_recall_fscore_support(all_labels, all_predict, average=None, labels=[0, 1])
210
+
211
+ if not mute:
212
+ print('==>acc<==', acc)
213
+ print('==>precision-recall-f1<==\n', prf)
214
+
215
+ return acc
216
+
217
+
218
+ if __name__ == '__main__':
219
+
220
+ batch_size, epoch_num, fp16, gpu, scheduler_setting, max_grad_norm, warmup_percent = parse_args()
221
+ fp16 = bool(fp16)
222
+ gpu = bool(gpu)
223
+
224
+ print('=====Arguments=====')
225
+ print('batch size:\t{}'.format(batch_size))
226
+ print('epoch num:\t{}'.format(epoch_num))
227
+ print('fp16:\t{}'.format(fp16))
228
+ print('gpu:\t{}'.format(gpu))
229
+ print('scheduler setting:\t{}'.format(scheduler_setting))
230
+ print('max grad norm:\t{}'.format(max_grad_norm))
231
+ print('warmup percent:\t{}'.format(warmup_percent))
232
+ print('=====Arguments=====')
233
+
234
+ label_num = 3
235
+ model_save_path = 'weights/entailment'
236
+
237
+ logging.basicConfig(format='%(asctime)s - %(message)s',
238
+ datefmt='%Y-%m-%d %H:%M:%S',
239
+ level=logging.INFO,
240
+ handlers=[LoggingHandler()])
241
+
242
+ # Read the dataset
243
+ train_data = get_examples('../data/allnli/train.jsonl')
244
+ dev_data = get_examples('../data/allnli/dev.jsonl')
245
+
246
+ logging.info('train data size {}'.format(len(train_data)))
247
+ logging.info('dev data size {}'.format(len(dev_data)))
248
+ total_steps = math.ceil(epoch_num * len(train_data) * 1. / batch_size)
249
+ warmup_steps = int(total_steps * warmup_percent)
250
+
251
+ model = NLI(batch_size=batch_size, parallel=True)
252
+
253
+ optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-6, correct_bias=False)
254
+ scheduler = get_scheduler(optimizer, scheduler_setting, warmup_steps=warmup_steps, t_total=total_steps)
255
+ if fp16:
256
+ try:
257
+ from apex import amp
258
+ except ImportError:
259
+ raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
260
+ model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
261
+
262
+ best_acc = -1.
263
+ for ep in range(epoch_num):
264
+ random.shuffle(train_data)
265
+ logging.info('\n=====epoch {}/{}====='.format(ep, epoch_num))
266
+ best_acc = train(model, optimizer, scheduler, train_data, dev_data, batch_size, fp16, gpu,
267
+ max_grad_norm, best_acc, model_save_path)
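A minimal usage sketch for the NLI class above (not part of the repo; the weights path is a placeholder, and it assumes `build_batch` accepts (premise, hypothesis) string pairs):

```python
from models.nli import NLI

model = NLI(model_path="path/to/nli.state_dict", device="cpu")  # placeholder path
model.eval()

pairs = [("A man is playing a guitar.", "Someone is making music.")]
labels, log_probs = model(pairs)
print(labels[0])  # one of: 'entailment', 'contradiction', 'neutral'
```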
models/qa_ranker.py ADDED
@@ -0,0 +1,284 @@
1
+ import os
2
+ import torch
3
+ import logging
4
+ import math
5
+ import argparse
6
+ import copy
7
+
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import numpy as np
11
+ import transformers
12
+
13
+ from torch import nn
14
+ from torch.nn import DataParallel
15
+ from transformers import BertModel, BertTokenizer
16
+
17
+ from .data_utils import build_batch, LoggingHandler, get_examples, get_qa_examples
18
+
19
+ from datetime import datetime
20
+ from tqdm import tqdm
21
+ from transformers import AutoConfig, AutoModel, AutoTokenizer, AdamW
22
+ from nltk.tokenize import word_tokenize
23
+ from sklearn.metrics import precision_recall_fscore_support
24
+
25
+ #os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
26
+ #os.environ["CUDA_VISIBLE_DEVICES"] = "8,15"
27
+
28
+
29
+ class PassageRanker(nn.Module):
30
+ """Performs prediction, given the input of BERT embeddings.
31
+ """
32
+
33
+ def __init__(self, model_path=None, gpu=True, label_num=2, batch_size=16):
34
+ super(PassageRanker, self).__init__()
35
+
36
+ lm = 'bert-large-cased'
37
+
38
+ if model_path is not None:
39
+ configuration = AutoConfig.from_pretrained(lm)
40
+ self.language_model = BertModel(configuration)
41
+ else:
42
+ self.language_model = AutoModel.from_pretrained(lm)
43
+
44
+ self.language_model = DataParallel(self.language_model)
45
+
46
+ self.tokenizer = AutoTokenizer.from_pretrained(lm)
47
+ self.vdim = 1024
48
+ self.max_length = 256
49
+
50
+ self.classification_head = nn.Linear(self.vdim, label_num)
51
+ self.gpu = gpu
52
+ self.batch_size = batch_size
53
+
54
+ # load trained model
55
+ if model_path is not None:
56
+ if gpu:
57
+ sdict = torch.load(model_path)
58
+ self.load_state_dict(sdict, strict=False)
59
+ self.to('cuda')
60
+ else:
61
+ sdict = torch.load(model_path, map_location=lambda storage, loc: storage)
62
+ self.load_state_dict(sdict, strict=False)
63
+ else:
64
+ if self.gpu:
65
+ self.to('cuda')
66
+
67
+ def load_model(self, sdict):
68
+ if self.gpu:
69
+ self.load_state_dict(sdict)
70
+ self.to('cuda')
71
+ else:
72
+ self.load_state_dict(sdict)
73
+
74
+ def forward(self, sent_pair_list):
75
+ all_probs = None
76
+ for batch_idx in range(0, len(sent_pair_list), self.batch_size):
77
+ probs = self.ff(sent_pair_list[batch_idx:batch_idx + self.batch_size]).data.cpu().numpy()
78
+ if all_probs is None:
79
+ all_probs = probs
80
+ else:
81
+ all_probs = np.append(all_probs, probs, axis=0)
82
+ labels = []
83
+ for pp in all_probs:
84
+ ll = np.argmax(pp)
85
+ if ll == 0:
86
+ labels.append('relevant')
87
+ else:
88
+ labels.append('irrelevant')
89
+ return labels, all_probs
90
+
91
+ def ff(self, sent_pair_list):
92
+ ids, types, masks = build_batch(self.tokenizer, sent_pair_list, max_len=self.max_length)
93
+ if ids is None:
94
+ return None
95
+ ids_tensor = torch.tensor(ids)
96
+ types_tensor = torch.tensor(types)
97
+ masks_tensor = torch.tensor(masks)
98
+
99
+ if self.gpu:
100
+ ids_tensor = ids_tensor.to('cuda')
101
+ types_tensor = types_tensor.to('cuda')
102
+ masks_tensor = masks_tensor.to('cuda')
103
+ # self.bert.to('cuda')
104
+ # self.nli_head.to('cuda')
105
+
106
+ cls_vecs = self.language_model(input_ids=ids_tensor, token_type_ids=types_tensor, attention_mask=masks_tensor)[1]
107
+ logits = self.classification_head(cls_vecs)
108
+ predict_probs = F.log_softmax(logits, dim=1)
109
+ return predict_probs
110
+
111
+ def save(self, output_path, config_dic=None, acc=None):
112
+ if acc is None:
113
+ model_name = 'qa_ranker.state_dict'
114
+ else:
115
+ model_name = 'qa_ranker_acc{}.state_dict'.format(acc)
116
+ opath = os.path.join(output_path, model_name)
117
+ if config_dic is None:
118
+ torch.save(self.state_dict(), opath)
119
+ else:
120
+ torch.save(config_dic, opath)
121
+
122
+ @staticmethod
123
+ def load(input_path, gpu=True, label_num=2, batch_size=16):
124
+ if gpu:
125
+ sdict = torch.load(input_path)
126
+ else:
127
+ sdict = torch.load(input_path, map_location=lambda storage, loc: storage)
128
+ model = PassageRanker(gpu=gpu, label_num=label_num, batch_size=batch_size)
129
+ model.load_state_dict(sdict)
130
+ return model
131
+
132
+
133
+ def get_scheduler(optimizer, scheduler: str, warmup_steps: int, t_total: int):
134
+ """
135
+ Returns the correct learning rate scheduler
136
+ """
137
+ scheduler = scheduler.lower()
138
+ if scheduler == 'constantlr':
139
+ return transformers.optimization.get_constant_schedule(optimizer)
140
+ elif scheduler == 'warmupconstant':
141
+ return transformers.optimization.get_constant_schedule_with_warmup(optimizer, warmup_steps)
142
+ elif scheduler == 'warmuplinear':
143
+ return transformers.optimization.get_linear_schedule_with_warmup(optimizer, warmup_steps, t_total)
144
+ elif scheduler == 'warmupcosine':
145
+ return transformers.optimization.get_cosine_schedule_with_warmup(optimizer, warmup_steps, t_total)
146
+ elif scheduler == 'warmupcosinewithhardrestarts':
147
+ return transformers.optimization.get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, warmup_steps,
148
+ t_total)
149
+ else:
150
+ raise ValueError("Unknown scheduler {}".format(scheduler))
151
+
152
+
153
+ def train(model, optimizer, scheduler, train_data, dev_data, batch_size, fp16, gpu,
154
+ max_grad_norm, best_acc, model_save_path):
155
+ loss_fn = nn.CrossEntropyLoss()
156
+ model.train()
157
+
158
+ step_cnt = 0
159
+ for pointer in tqdm(range(0, len(train_data), batch_size), desc='training'):
160
+ step_cnt += 1
161
+ sent_pairs = []
162
+ labels = []
163
+ for i in range(pointer, pointer + batch_size):
164
+ if i >= len(train_data):
165
+ break
166
+ sents = train_data[i].get_texts()
167
+ sent_pairs.append(sents)
168
+ labels.append(train_data[i].get_label())
169
+ predicted_probs = model.ff(sent_pairs)
170
+ if predicted_probs is None:
171
+ continue
172
+ true_labels = torch.LongTensor(labels)
173
+ if gpu:
174
+ true_labels = true_labels.to('cuda')
175
+ loss = loss_fn(predicted_probs, true_labels)
176
+ if fp16:
177
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
178
+ scaled_loss.backward()
179
+ torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
180
+ else:
181
+ loss.backward()
182
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
183
+
184
+ optimizer.step()
185
+ scheduler.step()
186
+ optimizer.zero_grad()
187
+
188
+ if step_cnt % 5000 == 0:
189
+ acc = evaluate(model, dev_data, mute=True)
190
+ logging.info('==> step {} dev acc: {}'.format(step_cnt, acc))
191
+ model.train() # model was in eval mode in evaluate(); re-activate the train mode
192
+ if acc > best_acc:
193
+ best_acc = acc
194
+ logging.info('Saving model...')
195
+ model.save(model_save_path, model.state_dict())
196
+
197
+ return best_acc
198
+
199
+
200
+ def parse_args():
201
+ ap = argparse.ArgumentParser("arguments for passage ranker training")
202
+ ap.add_argument('-b', '--batch_size', type=int, default=128, help='batch size')
203
+ ap.add_argument('-ep', '--epoch_num', type=int, default=10, help='epoch num')
204
+ ap.add_argument('--fp16', type=int, default=0, help='use apex mixed precision training (1) or not (0)')
205
+ ap.add_argument('--gpu', type=int, default=1, help='use gpu (1) or not (0)')
206
+ ap.add_argument('-ss', '--scheduler_setting', type=str, default='WarmupLinear',
207
+ choices=['WarmupLinear', 'ConstantLR', 'WarmupConstant', 'WarmupCosine',
208
+ 'WarmupCosineWithHardRestarts'])
209
+ ap.add_argument('-mg', '--max_grad_norm', type=float, default=1., help='maximum gradient norm')
210
+ ap.add_argument('-wp', '--warmup_percent', type=float, default=0.1,
211
+ help='fraction of steps used for warmup')
212
+
213
+ args = ap.parse_args()
214
+ return args.batch_size, args.epoch_num, args.fp16, args.gpu, args.scheduler_setting, args.max_grad_norm, args.warmup_percent
215
+
216
+
217
+ def evaluate(model, test_data, mute=False):
218
+ model.eval()
219
+ sent_pairs = [test_data[i].get_texts() for i in range(len(test_data))]
220
+ all_labels = [test_data[i].get_label() for i in range(len(test_data))]
221
+ _, probs = model(sent_pairs)
222
+ all_predict = [np.argmax(pp) for pp in probs]
223
+ assert len(all_predict) == len(all_labels)
224
+
225
+ acc = len([i for i in range(len(all_labels)) if all_predict[i] == all_labels[i]]) * 1. / len(all_labels)
226
+ prf = precision_recall_fscore_support(all_labels, all_predict, average=None, labels=[0, 1])
227
+
228
+ if not mute:
229
+ print('==>acc<==', acc)
230
+ print('label meanings: 0: relevant, 1: irrelevant')
231
+ print('==>precision-recall-f1<==\n', prf)
232
+
233
+ return acc
234
+
235
+
236
+ if __name__ == '__main__':
237
+
238
+ batch_size, epoch_num, fp16, gpu, scheduler_setting, max_grad_norm, warmup_percent = parse_args()
239
+ fp16 = bool(fp16)
240
+ gpu = bool(gpu)
241
+
242
+ print('=====Arguments=====')
243
+ print('batch size:\t{}'.format(batch_size))
244
+ print('epoch num:\t{}'.format(epoch_num))
245
+ print('fp16:\t{}'.format(fp16))
246
+ print('gpu:\t{}'.format(gpu))
247
+ print('scheduler setting:\t{}'.format(scheduler_setting))
248
+ print('max grad norm:\t{}'.format(max_grad_norm))
249
+ print('warmup percent:\t{}'.format(warmup_percent))
250
+ print('=====Arguments=====')
251
+
252
+ label_num = 2
253
+ model_save_path = 'weights/passage_ranker_2'
254
+
255
+ logging.basicConfig(format='%(asctime)s - %(message)s',
256
+ datefmt='%Y-%m-%d %H:%M:%S',
257
+ level=logging.INFO,
258
+ handlers=[LoggingHandler()])
259
+
260
+ # Read the dataset
261
+ train_data = get_qa_examples('../data/qa_ranking_large/train.jsonl', dev=False)
262
+ dev_data = get_qa_examples('../data/qa_ranking_large/dev.jsonl', dev=True)[:20000]
263
+
264
+ logging.info('train data size {}'.format(len(train_data)))
265
+ logging.info('dev data size {}'.format(len(dev_data)))
266
+ total_steps = math.ceil(epoch_num * len(train_data) * 1. / batch_size)
267
+ warmup_steps = int(total_steps * warmup_percent)
268
+
269
+ model = PassageRanker(gpu=gpu, batch_size=batch_size)
270
+
271
+ optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-6, correct_bias=False)
272
+ scheduler = get_scheduler(optimizer, scheduler_setting, warmup_steps=warmup_steps, t_total=total_steps)
273
+ if fp16:
274
+ try:
275
+ from apex import amp
276
+ except ImportError:
277
+ raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
278
+ model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
279
+
280
+ best_acc = -1.
281
+ for ep in range(epoch_num):
282
+ logging.info('\n=====epoch {}/{}====='.format(ep, epoch_num))
283
+ best_acc = train(model, optimizer, scheduler, train_data, dev_data, batch_size, fp16, gpu,
284
+ max_grad_norm, best_acc, model_save_path)
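A minimal usage sketch for PassageRanker (the weights path is a placeholder). Column 0 of the log-softmax output is the 'relevant' class, so sorting by it ranks passages from most to least relevant:

```python
from models.qa_ranker import PassageRanker

ranker = PassageRanker(model_path="path/to/passage_ranker.state_dict", gpu=False)
ranker.eval()

query = "what causes ocean tides?"
passages = [
    "Tides are caused mainly by the gravitational pull of the moon.",
    "The city council approved a new budget yesterday.",
]
labels, log_probs = ranker([(query, p) for p in passages])
ranked = sorted(zip(passages, log_probs), key=lambda x: x[1][0], reverse=True)
print([p for p, _ in ranked])
```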
models/sparse_retriever.py ADDED
@@ -0,0 +1,204 @@
1
+ import hashlib
2
+ import logging
3
+ import math
4
+ import re
5
+
6
+ from multiprocessing import Pool
7
+ from nltk import WordNetLemmatizer, pos_tag
8
+ from nltk.corpus import wordnet, stopwords
9
+ from typing import List
10
+ from tqdm import tqdm
11
+
12
+ lemmatizer = WordNetLemmatizer()
13
+ stopwords = set(stopwords.words('english'))
14
+
15
+
16
+ def get_ngrams(text_tokens: List[str], min_length=1, max_length=4) -> List[str]:
17
+ """
18
+ Gets word-level ngrams from text
19
+ :param text_tokens: the list of tokens used to generate ngrams
+ :param min_length: the minimum length of the generated ngrams in words
+ :param max_length: the maximum length of the generated ngrams in words
22
+ :return: list of ngrams (strings)
23
+ """
24
+ max_length = min(max_length, len(text_tokens))
25
+ all_ngrams = []
26
+ for n in range(min_length - 1, max_length):
27
+ ngrams = [" ".join(ngram) for ngram in zip(*[text_tokens[i:] for i in range(n + 1)])]
28
+ for ngram in ngrams:
29
+ if '#' not in ngram:
30
+ all_ngrams.append(ngram)
31
+
32
+ return all_ngrams
33
+
34
+
35
+ def clean_text(text):
36
+ """
37
+ Removes non-alphanumeric symbols from text
38
+ :param text:
39
+ :return: clean text
40
+ """
41
+ text = text.replace('-', ' ')
42
+ text = re.sub('[^a-zA-Z0-9 ]+', '', text)
43
+ text = re.sub(' +', ' ', text)
44
+ return text
45
+
46
+
47
+ def string_hash(string):
48
+ """
49
+ Returns a static hash value for a string
50
+ :param string:
51
+ :return:
52
+ """
53
+ return int(hashlib.md5(string.encode('utf8')).hexdigest(), 16)
54
+
55
+
56
+ def tokenize(text, lemmatize=True, ngrams_length=2):
57
+ """
58
+ :param text: the text to tokenize (stopwords from the module-level NLTK set are masked with '#' so ngrams spanning them are dropped)
60
+ :param lemmatize:
61
+ :param ngrams_length: the maximum number of tokens per ngram
62
+ :return:
63
+ """
64
+ tokens = clean_text(text).lower().split(' ')
65
+ tokens = [t for t in tokens if t != '']
66
+ if lemmatize:
67
+ lemmatized_words = []
68
+ pos_labels = pos_tag(tokens)
69
+ pos_labels = [pos[1][0].lower() for pos in pos_labels]
70
+
71
+ for i, word in enumerate(tokens):
72
+ if word in stopwords:
73
+ word = '#'
74
+ if pos_labels[i] == 'j':
75
+ pos_labels[i] = 'a' # 'j' <--> 'a' reassignment
76
+ if pos_labels[i] in ['r']: # For adverbs it's a bit different
77
+ try:
78
+ lemma = wordnet.synset(word + '.r.1').lemmas()[0].pertainyms()[0].name()
79
+ except Exception:  # the word has no adverb entry/pertainym in WordNet; fall back to the surface form
80
+ lemma = word
81
+ lemmatized_words.append(lemma)
82
+ elif pos_labels[i] in ['a', 's', 'v']: # For adjectives and verbs
83
+ lemmatized_words.append(lemmatizer.lemmatize(word, pos=pos_labels[i]))
84
+ else: # For nouns and everything else as it is the default kwarg
85
+ lemmatized_words.append(lemmatizer.lemmatize(word))
86
+ tokens = lemmatized_words
87
+
88
+ ngrams = get_ngrams(tokens, max_length=ngrams_length)
89
+
90
+ return ngrams
91
+
92
+
93
+ class SparseRetriever:
94
+ def __init__(self, ngram_buckets=16777216, k1=1.5, b=0.75, epsilon=0.25,
95
+ max_relative_freq=0.5, workers=4):
96
+ self.ngram_buckets = ngram_buckets
97
+ self.k1 = k1
98
+ self.b = b
99
+ self.epsilon = epsilon
100
+ self.max_relative_freq = max_relative_freq
101
+ self.corpus_size = 0
102
+ self.avgdl = 0
103
+
104
+ self.inverted_index = {}
105
+ self.idf = {}
106
+ self.doc_len = []
107
+ self.workers = workers
108
+
109
+ def index_documents(self, documents):
110
+ with Pool(self.workers) as p:
111
+ tokenized_documents = list(tqdm(p.imap(tokenize, documents), total=len(documents), desc='tokenized'))
112
+
113
+ logging.info('Building inverted index...')
114
+
115
+ self.corpus_size = len(tokenized_documents)
116
+ nd = self._create_inverted_index(tokenized_documents)
117
+ self._calc_idf(nd)
118
+
119
+ logging.info('Built inverted index')
120
+
121
+ def _create_inverted_index(self, documents):
122
+ nd = {} # word -> number of documents with word
123
+ num_doc = 0
124
+ for doc_id, document in enumerate(tqdm(documents, desc='indexed')):
125
+ self.doc_len.append(len(document))
126
+ num_doc += len(document)
127
+
128
+ frequencies = {}
129
+ for word in document:
130
+ if word not in frequencies:
131
+ frequencies[word] = 0
132
+ frequencies[word] += 1
133
+
134
+ for word, freq in frequencies.items():
135
+ hashed_word = string_hash(word) % self.ngram_buckets
136
+ if hashed_word not in self.inverted_index:
137
+ self.inverted_index[hashed_word] = [(doc_id, freq)]
138
+ else:
139
+ self.inverted_index[hashed_word].append((doc_id, freq))
140
+
141
+ for word, freq in frequencies.items():
142
+ hashed_word = string_hash(word) % self.ngram_buckets
143
+ if hashed_word not in nd:
144
+ nd[hashed_word] = 0
145
+ nd[hashed_word] += 1
146
+
147
+ self.avgdl = num_doc / self.corpus_size
148
+
149
+ return nd
150
+
151
+ def _calc_idf(self, nd):
152
+ """
153
+ Calculates frequencies of terms in documents and in corpus.
154
+ This algorithm sets a floor on the idf values to eps * average_idf
155
+ """
156
+ # collect idf sum to calculate an average idf for epsilon value
157
+ idf_sum = 0
158
+ # collect words with negative idf to set them a special epsilon value.
159
+ # idf can be negative if word is contained in more than half of documents
160
+ negative_idfs = []
161
+ for word, freq in nd.items():
162
+ idf = math.log(self.corpus_size - freq + 0.5) - math.log(freq + 0.5)
163
+ r_freq = float(freq) / self.corpus_size
164
+ if r_freq > self.max_relative_freq:
165
+ continue
166
+ self.idf[word] = idf
167
+ idf_sum += idf
168
+ if idf < 0:
169
+ negative_idfs.append(word)
170
+ self.average_idf = idf_sum / len(self.idf)
171
+
172
+ eps = self.epsilon * self.average_idf
173
+ for word in negative_idfs:
174
+ self.idf[word] = eps
175
+
176
+ def _get_scores(self, query):
177
+ """
178
+ The ATIRE BM25 variant uses an idf function which uses a log(idf) score. To prevent negative idf scores,
179
+ this algorithm also adds a floor to the idf value of epsilon.
180
+ See [Trotman, A., X. Jia, M. Crane, Towards an Efficient and Effective Search Engine] for more info
181
+ :param query:
182
+ :return:
183
+ """
184
+ query = tokenize(query)
185
+ scores = {}
186
+ for q in query:
187
+ hashed_word = string_hash(q) % self.ngram_buckets
188
+ idf = self.idf.get(hashed_word)
189
+ if idf:
190
+ doc_freqs = self.inverted_index[hashed_word]
191
+ for doc_id, freq in doc_freqs:
192
+ score = idf * (freq * (self.k1 + 1) /
+ (freq + self.k1 * (1 - self.b + self.b * self.doc_len[doc_id] / self.avgdl)))
194
+ if doc_id in scores:
195
+ scores[doc_id] += score
196
+ else:
197
+ scores[doc_id] = score
198
+ return sorted(scores.items(), key=lambda x: x[1], reverse=True)
199
+
200
+ def search(self, queries, topk=100):
201
+ results = [self._get_scores(q) for q in tqdm(queries, desc='searched')]
202
+ results = [r[:topk] for r in results]
203
+ logging.info('Done searching')
204
+ return results
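A minimal usage sketch for SparseRetriever (assumes the NLTK stopwords, wordnet, and POS-tagger data are downloaded; `index_documents` spawns a multiprocessing pool, so run it under a __main__ guard):

```python
from models.sparse_retriever import SparseRetriever

if __name__ == "__main__":
    docs = [
        "the cat sat on the mat",
        "dogs chase cats around the yard",
        "stock markets rallied on Friday",
    ]
    retriever = SparseRetriever(workers=1)
    retriever.index_documents(docs)
    # Each result is a list of (doc_id, bm25_score) pairs, best first.
    print(retriever.search(["cat on a mat"], topk=2)[0])
```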
models/sparse_retriever_fast.py ADDED
@@ -0,0 +1,48 @@
1
+ import logging
2
+ import os
3
+ import tantivy
4
+
5
+ from tqdm import tqdm
6
+
7
+
8
+ class SparseRetrieverFast:
9
+ def __init__(self, path='sparse_index', load=True):
10
+ if not os.path.exists(path):
11
+ os.mkdir(path)
12
+ schema_builder = tantivy.SchemaBuilder()
13
+ schema_builder.add_text_field("body", stored=False)
14
+ schema_builder.add_unsigned_field("doc_id", stored=True)
15
+ schema = schema_builder.build()
16
+ self.index = tantivy.Index(schema, path=path, reuse=load)
17
+ self.searcher = self.index.searcher()
18
+
19
+ def index_documents(self, documents):
20
+ logging.info('Building sparse index of {} docs...'.format(len(documents)))
21
+ writer = self.index.writer()
22
+ for i, doc in enumerate(documents):
23
+ writer.add_document(tantivy.Document(
24
+ body=[doc],
25
+ doc_id=i
26
+ ))
27
+ if (i+1) % 100000 == 0:
28
+ writer.commit()
29
+ logging.info('Indexed {} docs'.format(i+1))
30
+ writer.commit()
31
+ logging.info('Built sparse index')
32
+ self.index.reload()
33
+ self.searcher = self.index.searcher()
34
+
35
+ def search(self, queries, topk=100):
36
+ results = []
37
+ for q in tqdm(queries, desc='searched'):
38
+ docs = []
39
+ try:
40
+ query = self.index.parse_query(q, ["body"])
41
+ scores = self.searcher.search(query, topk).hits
42
+ docs = [(self.searcher.doc(doc_id)['doc_id'][0], score)
43
+ for score, doc_id in scores]
44
+ except Exception:
+ # tantivy raises on queries it cannot parse; fall through with no hits
+ pass
46
+ results.append(docs)
47
+
48
+ return results
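A minimal usage sketch for SparseRetrieverFast (assumes the `tantivy` package is installed; the index directory is a placeholder):

```python
from models.sparse_retriever_fast import SparseRetrieverFast

retriever = SparseRetrieverFast(path="sparse_index_demo", load=False)  # fresh index
retriever.index_documents(["the cat sat on the mat", "dogs chase cats"])
print(retriever.search(["cat"], topk=2)[0])  # [(doc_id, score), ...]
```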
models/text_encoder.py ADDED
@@ -0,0 +1,867 @@
1
+ import importlib
2
+ import json
3
+ import logging
4
+ import os
5
+
6
+ import hashlib
7
+ import pickle
8
+ import torch
9
10
+ import transformers
11
+
12
+ import numpy as np
13
+
14
+ from pprint import pprint
15
+ from typing import Iterable, Tuple, Type, Dict, List, Union, Optional
16
+ from numpy.core.multiarray import ndarray
17
+ from torch import nn
18
+ from torch.nn import DataParallel
19
+ from torch.optim import Optimizer
20
+ from torch.utils.data import DataLoader, Dataset
21
+
22
+ from models.data_utils import get_qnli_examples, get_single_examples, get_ict_examples, get_examples, get_qar_examples, \
23
+ get_qar_artificial_examples, get_retrieval_examples
24
+ from tqdm import tqdm
25
+
26
+ from collections import OrderedDict
27
+ from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel
28
+
29
+ from models.vector_index import VectorIndex
30
+
31
+
32
+ class InputExample:
33
+ """
34
+ Structure for one input example with texts, the label and a unique id
35
+ """
36
+ def __init__(self, guid: str, texts: List[str], label: Union[int, float]):
37
+ """
38
+ Creates one InputExample with the given texts, guid and label
39
+
40
+ str.strip() is called on both texts.
41
+
42
+ :param guid
43
+ id for the example
44
+ :param texts
45
+ the texts for the example
46
+ :param label
47
+ the label for the example
48
+ """
49
+ self.guid = guid
50
+ self.texts = [text.strip() for text in texts]
51
+ self.label = label
52
+
53
+
54
+ class LoggingHandler(logging.Handler):
55
+ def __init__(self, level=logging.NOTSET):
56
+ super().__init__(level)
57
+
58
+ def emit(self, record):
59
+ try:
60
+ msg = self.format(record)
61
+ tqdm.write(msg)
62
+ self.flush()
63
+ except (KeyboardInterrupt, SystemExit):
64
+ raise
65
+ except:
66
+ self.handleError(record)
67
+
68
+
69
+ def import_from_string(dotted_path):
70
+ """
71
+ Import a dotted module path and return the attribute/class designated by the
72
+ last name in the path. Raise ImportError if the import failed.
73
+ """
74
+ try:
75
+ module_path, class_name = dotted_path.rsplit('.', 1)
76
+ except ValueError:
77
+ msg = "%s doesn't look like a module path" % dotted_path
78
+ raise ImportError(msg)
79
+
80
+ module = importlib.import_module(module_path)
81
+
82
+ try:
83
+ return getattr(module, class_name)
84
+ except AttributeError:
85
+ msg = 'Module "%s" does not define a "%s" attribute/class' % (module_path, class_name)
86
+ raise ImportError(msg)
87
+
88
+
89
+ def batch_to_device(batch, target_device: torch.device):
90
+ """
91
+ send a batch to a device
92
+
93
+ :param batch:
94
+ :param target_device:
95
+ :return: the batch sent to the device
96
+ """
97
+ features = batch['features']
98
+ for paired_sentence_idx in range(len(features)):
99
+ for feature_name in features[paired_sentence_idx]:
100
+ features[paired_sentence_idx][feature_name] = features[paired_sentence_idx][feature_name].to(target_device)
101
+
102
+ labels = batch['labels'].to(target_device)
103
+ return features, labels
104
+
105
+
106
+ class BERT(nn.Module):
107
+ """BERT model to generate token embeddings.
108
+ Each token is mapped to an output vector from BERT.
109
+ """
110
+ def __init__(self, model_name_or_path: str, max_seq_length: int = 128, do_lower_case: Optional[bool] = None, model_args: Dict = {}, tokenizer_args: Dict = {}):
111
+ super(BERT, self).__init__()
112
+ self.config_keys = ['max_seq_length', 'do_lower_case']
113
+ self.do_lower_case = do_lower_case
114
+
115
+ if max_seq_length > 510:
116
+ logging.warning("BERT only allows a max_seq_length of 510 (512 with special tokens). Value will be set to 510")
117
+ max_seq_length = 510
118
+ self.max_seq_length = max_seq_length
119
+
120
+ if self.do_lower_case is not None:
121
+ tokenizer_args['do_lower_case'] = do_lower_case
122
+
123
+ self.model = AutoModel.from_pretrained(model_name_or_path, **model_args)
124
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **tokenizer_args)
125
+
126
+ def forward(self, features):
127
+ """Returns token_embeddings, cls_token"""
128
+ output_states = self.model(**features)
129
+ output_tokens = output_states[0]
130
+ cls_tokens = output_tokens[:, 0, :] # CLS token is first token
131
+ features.update({'token_embeddings': output_tokens, 'cls_token_embeddings': cls_tokens, 'attention_mask': features['attention_mask']})
132
+
133
+ if len(output_states) > 2:
134
+ features.update({'all_layer_embeddings': output_states[2]})
135
+
136
+ return features
137
+
138
+ def get_word_embedding_dimension(self) -> int:
139
+ return self.model.config.hidden_size
140
+
141
+ def tokenize(self, text: str) -> List[int]:
142
+ """
143
+ Tokenizes a text and maps tokens to token-ids
144
+ """
145
+ return self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
146
+
147
+ def get_sentence_features(self, tokens: List[int], pad_seq_length: int):
148
+ """
149
+ Convert tokenized sentence in its embedding ids, segment ids and mask
150
+ :param tokens:
151
+ a tokenized sentence
152
+ :param pad_seq_length:
153
+ the maximal length of the sequence. Cannot be greater than self.sentence_transformer_config.max_seq_length
154
+ :return: embedding ids, segment ids and mask for the sentence
155
+ """
156
+ pad_seq_length = min(pad_seq_length, self.max_seq_length) + 2 ##Add Space for CLS + SEP token
157
+
158
+ return self.tokenizer.prepare_for_model(tokens, max_length=pad_seq_length, pad_to_max_length=True, return_tensors='pt')
159
+
160
+ def get_config_dict(self):
161
+ return {key: self.__dict__[key] for key in self.config_keys}
162
+
163
+ def save(self, output_path: str):
164
+ self.model.save_pretrained(output_path)
165
+ self.tokenizer.save_pretrained(output_path)
166
+
167
+ with open(os.path.join(output_path, 'sentence_bert_config.json'), 'w') as fOut:
168
+ json.dump(self.get_config_dict(), fOut, indent=2)
169
+
170
+ @staticmethod
171
+ def load(input_path: str):
172
+ with open(os.path.join(input_path, 'sentence_bert_config.json')) as fIn:
173
+ config = json.load(fIn)
174
+ return BERT(model_name_or_path=input_path, **config)
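A minimal usage sketch for the SentenceTransformer wrapper defined next (the model folder is a placeholder and must contain the modules.json it expects; this assumes the truncated remainder of the file defines the tokenize/get_sentence_features helpers that encode() relies on):

```python
from models.text_encoder import SentenceTransformer

encoder = SentenceTransformer(model_path="path/to/encoder_folder", device="cpu")
embeddings = encoder.encode(["what causes ocean tides?"], batch_size=8)
print(len(embeddings), embeddings[0].shape)
```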
175
+
176
+
177
+ class SentenceTransformer(nn.Sequential):
178
+ def __init__(self, model_path: str = None, modules: Iterable[nn.Module] = None, device: str = None,
179
+ parallel=False):
180
+ if modules is not None and not isinstance(modules, OrderedDict):
181
+ modules = OrderedDict([(str(idx), module) for idx, module in enumerate(modules)])
182
+
183
+ if model_path is not None:
184
+ logging.info("Load SentenceTransformer from folder: {}".format(model_path))
185
+
186
+ with open(os.path.join(model_path, 'modules.json')) as fIn:
187
+ contained_modules = json.load(fIn)
188
+
189
+ modules = OrderedDict()
190
+ for module_config in contained_modules:
191
+ module_class = import_from_string(module_config['type'])
192
+ module = module_class.load(os.path.join(model_path, module_config['path']))
193
+ if 'BERT' in module_config['type']:
194
+ if parallel:
195
+ module = DataParallel(module)
196
+ modules[module_config['name']] = module
197
+
198
+ super().__init__(modules)
199
+
200
+ self.best_score = -1
201
+ self.total_steps = 0
202
+
203
+ if device is None:
204
+ device = "cuda" if torch.cuda.is_available() else "cpu"
205
+ logging.info("Use pytorch device: {}".format(device))
206
+ self.device = torch.device(device)
207
+ self.to(device)
208
+
209
+ def _eval_during_training(self, evaluator, output_path, save_best_model, epoch, steps):
210
+ """Runs evaluation during the training"""
211
+ if evaluator is not None:
212
+ score = evaluator(self, output_path=output_path, epoch=epoch, steps=steps)
213
+ print(score)
214
+ if score > self.best_score and save_best_model:
215
+ print('saving')
216
+ self.save(output_path)
217
+ self.best_score = score
218
+
219
+ def evaluate(self, evaluator, output_path: str = None):
220
+ """
221
+ Evaluate the model
222
+
223
+ :param evaluator:
224
+ the evaluator
225
+ :param output_path:
226
+ the evaluator can write the results to this path
227
+ """
228
+ if output_path is not None:
229
+ os.makedirs(output_path, exist_ok=True)
230
+ return evaluator(self, output_path)
231
+
232
+ def encode(self, sentences: List[str], batch_size: int = 8, show_progress_bar: bool = None, output_value: str = 'sentence_embedding', convert_to_numpy: bool = True) -> List[ndarray]:
233
+ """
234
+ Computes sentence embeddings
235
+ :param sentences:
236
+ the sentences to embed
237
+ :param batch_size:
238
+ the batch size used for the computation
239
+ :param show_progress_bar:
240
+ Output a progress bar when encoding sentences
241
+ :param output_value:
242
+ Default sentence_embedding, to get sentence embeddings. Can be set to token_embeddings
243
+ to get wordpiece token embeddings.
244
+ :param convert_to_numpy:
245
+ If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors.
246
+ :return:
247
+ Depending on convert_to_numpy, either a list of numpy vectors or a list of pytorch tensors
248
+ """
249
+ self.eval()
250
+ if show_progress_bar is None:
251
+ show_progress_bar = (logging.getLogger().getEffectiveLevel()==logging.INFO or logging.getLogger().getEffectiveLevel()==logging.DEBUG)
252
+
253
+ all_embeddings = []
254
+ length_sorted_idx = np.argsort([len(sen) for sen in sentences])
255
+
256
+ iterator = range(0, len(sentences), batch_size)
257
+ if show_progress_bar:
258
+ iterator = tqdm(iterator, desc="Batches")
259
+
260
+ for batch_idx in iterator:
261
+ batch_tokens = []
262
+
263
+ batch_start = batch_idx
264
+ batch_end = min(batch_start + batch_size, len(sentences))
265
+
266
+ longest_seq = 0
267
+
268
+ for idx in length_sorted_idx[batch_start: batch_end]:
269
+ sentence = sentences[idx]
270
+ tokens = self.tokenize(sentence)
271
+ longest_seq = max(longest_seq, len(tokens))
272
+ batch_tokens.append(tokens)
273
+
274
+ features = {}
275
+ for text in batch_tokens:
276
+ sentence_features = self.get_sentence_features(text, longest_seq)
277
+
278
+ for feature_name in sentence_features:
279
+ if feature_name not in features:
280
+ features[feature_name] = []
281
+ features[feature_name].append(sentence_features[feature_name])
282
+
283
+ for feature_name in features:
284
+ try:
285
+ features[feature_name] = torch.tensor(np.asarray(features[feature_name])).to(self.device)
286
+ except Exception:  # fall back when the features cannot be stacked into a single array
287
+ features[feature_name] = torch.cat(features[feature_name]).to(self.device)
288
+
289
+ with torch.no_grad():
290
+ out_features = self.forward(features)
291
+ embeddings = out_features[output_value]
292
+
293
+ if output_value == 'token_embeddings':
294
+ #Set token embeddings to 0 for padding tokens
295
+ input_mask = out_features['attention_mask']
296
+ input_mask_expanded = input_mask.unsqueeze(-1).expand(embeddings.size()).float()
297
+ embeddings = embeddings * input_mask_expanded
298
+
299
+ if convert_to_numpy:
300
+ embeddings = embeddings.to('cpu').numpy()
301
+
302
+ all_embeddings.extend(embeddings)
303
+
304
+ reverting_order = np.argsort(length_sorted_idx)
305
+ all_embeddings = [all_embeddings[idx] for idx in reverting_order]
306
+
307
+ return all_embeddings
308
+
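Note that `encode` sorts the inputs by length for efficient batching but restores the original order before returning, so the output list aligns index-for-index with the input sentences. A hedged sketch of typical downstream use (`model` is assumed to be an already-built SentenceTransformer):

```python
import numpy as np

# `model` is assumed to be a loaded SentenceTransformer instance.
sentences = ["the cat sat on the mat",
             "a feline rested on the rug",
             "stock prices fell sharply"]
embs = model.encode(sentences, batch_size=8, convert_to_numpy=True)

def cosine(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

print(cosine(embs[0], embs[1]))  # paraphrase pair: expected to score higher
print(cosine(embs[0], embs[2]))  # unrelated pair: expected to score lower
```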
309
+ def get_max_seq_length(self):
310
+ if hasattr(self._first_module(), 'max_seq_length'):
311
+ return self._first_module().max_seq_length
312
+
313
+ return None
314
+
315
+ def tokenize(self, text):
316
+ return self._first_module().tokenize(text)
317
+
318
+ def get_sentence_features(self, *features):
319
+ return self._first_module().get_sentence_features(*features)
320
+
321
+ def get_sentence_embedding_dimension(self):
322
+ return self._last_module().get_sentence_embedding_dimension()
323
+
324
+ def _first_module(self):
325
+ """Returns the first module of this sequential embedder"""
326
+ try:
327
+ return self._modules[next(iter(self._modules))].module
328
+ except AttributeError:  # first module is not wrapped in DataParallel
329
+ return self._modules[next(iter(self._modules))]
330
+
331
+ def _last_module(self):
332
+ """Returns the last module of this sequential embedder"""
333
+ return self._modules[next(reversed(self._modules))]
334
+
335
+ def save(self, path):
336
+ """
337
+ Saves all elements of this sequential sentence embedder into separate sub-folders
338
+ """
339
+ if path is None:
340
+ return
341
+
342
+ logging.info("Save model to {}".format(path))
343
+ contained_modules = []
344
+
345
+ for idx, name in enumerate(self._modules):
346
+ module = self._modules[name]
347
+ if isinstance(module, DataParallel):
348
+ module = module.module
349
+ model_path = os.path.join(path, str(idx) + "_" + type(module).__name__)
350
+ os.makedirs(model_path, exist_ok=True)
351
+ module.save(model_path)
352
+ contained_modules.append(
353
+ {'idx': idx, 'name': name, 'path': os.path.basename(model_path), 'type': type(module).__module__})
354
+
355
+ with open(os.path.join(path, 'modules.json'), 'w') as fOut:
356
+ json.dump(contained_modules, fOut, indent=2)
357
+
358
+ with open(os.path.join(path, 'config.json'), 'w') as fOut:
359
+ json.dump({'__version__': '1.0'}, fOut, indent=2)
360
+
361
+ def smart_batching_collate(self, batch):
362
+ """
363
+ Transforms a batch from a SmartBatchingDataset to a batch of tensors for the model
364
+ :param batch:
365
+ a batch from a SmartBatchingDataset
366
+ :return:
367
+ a batch of tensors for the model
368
+ """
369
+ num_texts = len(batch[0][0])
370
+
371
+ labels = []
372
+ paired_texts = [[] for _ in range(num_texts)]
373
+ max_seq_len = [0] * num_texts
374
+ for tokens, label in batch:
375
+ labels.append(label)
376
+ for i in range(num_texts):
377
+ paired_texts[i].append(tokens[i])
378
+ max_seq_len[i] = max(max_seq_len[i], len(tokens[i]))
379
+
380
+ features = []
381
+ for idx in range(num_texts):
382
+ max_len = max_seq_len[idx]
383
+ feature_lists = {}
384
+
385
+ for text in paired_texts[idx]:
386
+ sentence_features = self.get_sentence_features(text, max_len)
387
+
388
+ for feature_name in sentence_features:
389
+ if feature_name not in feature_lists:
390
+ feature_lists[feature_name] = []
391
+
392
+ feature_lists[feature_name].append(sentence_features[feature_name])
393
+
394
+ for feature_name in feature_lists:
395
+ try:
396
+ feature_lists[feature_name] = torch.tensor(np.asarray(feature_lists[feature_name]))
397
+ except Exception:  # fall back when the features cannot be stacked into a single array
398
+ feature_lists[feature_name] = torch.cat(feature_lists[feature_name])
399
+
400
+ features.append(feature_lists)
401
+
402
+ return {'features': features, 'labels': torch.stack(labels)}
403
+
404
+ def fit(self,
405
+ train_objectives: Iterable[Tuple[DataLoader, nn.Module]],
406
+ evaluator,
407
+ epochs: int = 1,
408
+ steps_per_epoch=None,
409
+ scheduler: str = 'WarmupLinear',
410
+ warmup_steps: int = 10000,
411
+ optimizer_class: Type[Optimizer] = transformers.AdamW,
412
+ optimizer_params: Dict[str, object] = {'lr': 2e-5, 'eps': 1e-6, 'correct_bias': False},
413
+ acc_steps: int = 4,
414
+ weight_decay: float = 0.01,
415
+ evaluation_steps: int = 0,
416
+ output_path: str = None,
417
+ save_best_model: bool = True,
418
+ max_grad_norm: float = 1,
419
+ fp16: bool = False,
420
+ fp16_opt_level: str = 'O1',
421
+ local_rank: int = -1
422
+ ):
423
+ if output_path is not None:
424
+ os.makedirs(output_path, exist_ok=True)
425
+
426
+ dataloaders = [dataloader for dataloader, _ in train_objectives]
427
+
428
+ # Use smart batching
429
+ for dataloader in dataloaders:
430
+ dataloader.collate_fn = self.smart_batching_collate
431
+
432
+ loss_models = [loss for _, loss in train_objectives]
433
+ device = self.device
434
+
435
+ for loss_model in loss_models:
436
+ loss_model.to(device)
437
+
438
+ if steps_per_epoch is None or steps_per_epoch == 0:
439
+ steps_per_epoch = min([len(dataloader) for dataloader in dataloaders])
440
+
441
+ num_train_steps = int(steps_per_epoch * epochs)
442
+
443
+ # Prepare optimizers
444
+ optimizers = []
445
+ schedulers = []
446
+ for loss_model in loss_models:
447
+ param_optimizer = list(loss_model.named_parameters())
448
+
449
+ no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
450
+ optimizer_grouped_parameters = [
451
+ {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
452
+ 'weight_decay': weight_decay},
453
+ {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
454
+ ]
455
+ t_total = num_train_steps
456
+ if local_rank != -1:
457
+ t_total = t_total // torch.distributed.get_world_size()
458
+
459
+ optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params)
460
+ scheduler_obj = self._get_scheduler(optimizer, scheduler=scheduler, warmup_steps=warmup_steps,
461
+ t_total=t_total)
462
+
463
+ optimizers.append(optimizer)
464
+ schedulers.append(scheduler_obj)
465
+
466
+ if fp16:
467
+ try:
468
+ from apex import amp
469
+ except ImportError:
470
+ raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
471
+
472
+ for train_idx in range(len(loss_models)):
473
+ model, optimizer = amp.initialize(loss_models[train_idx], optimizers[train_idx],
474
+ opt_level=fp16_opt_level)
475
+ loss_models[train_idx] = model
476
+ optimizers[train_idx] = optimizer
477
+
478
+ data_iterators = [iter(dataloader) for dataloader in dataloaders]
479
+ num_train_objectives = len(train_objectives)
480
+
481
+ for epoch in range(epochs):
482
+ #logging.info('Epoch {}'.format(epoch))
483
+ epoch_loss = 0
484
+ epoch_steps = 0
485
+ for loss_model in loss_models:
486
+ loss_model.zero_grad()
487
+ loss_model.train()
488
+
489
+ step_iterator = tqdm(range(steps_per_epoch), desc='loss: -')
490
+ for step in step_iterator:
491
+ for train_idx in range(num_train_objectives):
492
+ loss_model = loss_models[train_idx]
493
+ optimizer = optimizers[train_idx]
494
+ scheduler = schedulers[train_idx]
495
+ data_iterator = data_iterators[train_idx]
496
+
497
+ try:
498
+ data = next(data_iterator)
499
+ except StopIteration:
500
+ # logging.info("Restart data_iterator")
501
+ data_iterator = iter(dataloaders[train_idx])
502
+ data_iterators[train_idx] = data_iterator
503
+ data = next(data_iterator)
504
+
505
+ features, labels = batch_to_device(data, self.device)
506
+ loss_value = loss_model(features, labels)
507
+ loss_value = loss_value / acc_steps
508
+
509
+ if fp16:
510
+ with amp.scale_loss(loss_value, optimizer) as scaled_loss:
511
+ scaled_loss.backward()
512
+ torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
513
+ else:
514
+ loss_value.backward()
515
+ torch.nn.utils.clip_grad_norm_(loss_model.parameters(), max_grad_norm)
516
+
517
+ if (step + 1) % acc_steps == 0:
518
+ optimizer.step()
519
+ scheduler.step()
520
+ optimizer.zero_grad()
521
+
522
+ self.total_steps += 1
523
+
524
+ if (step + 1) % acc_steps == 0:
525
+ epoch_steps += 1
526
+ epoch_loss += loss_value.item()
527
+ step_iterator.set_description('loss: {} - acc steps: {}'.format((epoch_loss / epoch_steps),
528
+ (self.total_steps / acc_steps)))
529
+
530
+ if evaluation_steps > 0 and self.total_steps > 0 and (self.total_steps / acc_steps) % evaluation_steps == 0:
531
+ self._eval_during_training(evaluator, output_path, save_best_model, epoch, epoch_steps)
532
+ for loss_model in loss_models:
533
+ loss_model.zero_grad()
534
+ loss_model.train()
535
+
536
+ self._eval_during_training(evaluator, output_path, save_best_model, epoch, -1)
537
+
538
+ def _get_scheduler(self, optimizer, scheduler: str, warmup_steps: int, t_total: int):
539
+ """
540
+ Returns the correct learning rate scheduler
541
+ """
542
+ scheduler = scheduler.lower()
543
+ if scheduler == 'constantlr':
544
+ return transformers.get_constant_schedule(optimizer)
545
+ elif scheduler == 'warmupconstant':
546
+ return transformers.get_constant_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps)
547
+ elif scheduler == 'warmuplinear':
548
+ return transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)
549
+ elif scheduler == 'warmupcosine':
550
+ return transformers.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)
551
+ elif scheduler == 'warmupcosinewithhardrestarts':
552
+ return transformers.get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)
553
+ else:
554
+ raise ValueError("Unknown scheduler {}".format(scheduler))
555
+
556
+
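For reference, the default 'WarmupLinear' schedule ramps the learning rate from 0 over `warmup_steps` and then decays it linearly to 0 at `t_total`. A small self-contained sketch of the shape, using the same `transformers` helper (the SGD optimizer and linear layer are stand-ins):

```python
import torch
import transformers

params = torch.nn.Linear(4, 1).parameters()  # stand-in parameters
opt = torch.optim.SGD(params, lr=1.0)
sched = transformers.get_linear_schedule_with_warmup(
    opt, num_warmup_steps=10, num_training_steps=100)

lrs = []
for _ in range(100):
    opt.step()
    sched.step()
    lrs.append(sched.get_last_lr()[0])
print(max(lrs), lrs[-1])  # peaks at 1.0 right after warmup, decays to 0.0
```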
557
+ class SentencesDataset(Dataset):
558
+ def __init__(self, examples: List[InputExample], model: SentenceTransformer, show_progress_bar: bool = None):
559
+ self.examples = examples
560
+ self.model = model
561
+
562
+ def __getitem__(self, item):
563
+ tokenized_texts = [self.model.tokenize(text) for text in self.examples[item].texts]  # was `model.tokenize`, which only resolved via the training script's global
564
+ return tokenized_texts, torch.tensor(self.examples[item].label, dtype=torch.float)
565
+
566
+ def __len__(self):
567
+ return len(self.examples)
568
+
569
+
570
+ class MultipleNegativesRankingLossANN(nn.Module):
571
+ def __init__(self, model: SentenceTransformer, negative_samples=4):
572
+ super(MultipleNegativesRankingLossANN, self).__init__()
573
+ self.model = model
574
+ self.negative_samples = negative_samples
575
+
576
+ def forward(self, sentence_features: Iterable[Dict[str, torch.Tensor]], labels: torch.Tensor):
577
+ embeddings = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features]
578
+ return self.multiple_negatives_ranking_loss(embeddings)
579
+
580
+ def multiple_negatives_ranking_loss(self, embeddings: List[torch.Tensor]):
581
+ positive_loss = torch.mean(torch.sum(embeddings[0] * embeddings[1], dim=-1))
582
+
583
+ negative_loss = torch.sum(embeddings[0] * embeddings[1], dim=-1)
584
+ for i in range(2, self.negative_samples + 2):
585
+ negative_loss = torch.cat((negative_loss, torch.sum(embeddings[0] * embeddings[i], dim=-1)), dim=-1)
586
+ negative_loss = torch.mean(torch.logsumexp(negative_loss, dim=-1))
587
+
588
+ return -positive_loss + negative_loss
589
+
590
+
591
+ class MultipleNegativesRankingLoss(nn.Module):
592
+ def __init__(self, model: SentenceTransformer):
593
+ super(MultipleNegativesRankingLoss, self).__init__()
594
+ self.model = model
595
+
596
+ def forward(self, sentence_features: Iterable[Dict[str, torch.Tensor]], labels: torch.Tensor):
597
+ reps = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features]
598
+
599
+ reps_a, reps_b = reps
600
+ return self.multiple_negatives_ranking_loss(reps_a, reps_b)
601
+
602
+ def multiple_negatives_ranking_loss(self, embeddings_a: torch.Tensor, embeddings_b: torch.Tensor):
603
+ scores = torch.matmul(embeddings_a, embeddings_b.t())
604
+ diagonal_mean = torch.mean(torch.diag(scores))
605
+ mean_log_row_sum_exp = torch.mean(torch.logsumexp(scores, dim=1))
606
+ return -diagonal_mean + mean_log_row_sum_exp
607
+
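The loss above is the standard in-batch-negatives objective: with the score matrix S = A·Bᵀ, minimizing −mean(diag(S)) + mean(logsumexp(S, dim=1)) is exactly row-wise cross-entropy where the correct "class" of row i is column i. A small numeric check (random tensors, purely illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
a = torch.randn(4, 8)  # batch of query embeddings
b = torch.randn(4, 8)  # batch of matching passages; row i pairs with row i

scores = a @ b.t()
mnrl = -scores.diag().mean() + torch.logsumexp(scores, dim=1).mean()

# The same value as cross-entropy with the diagonal as the target class:
ce = F.cross_entropy(scores, torch.arange(4))
print(torch.allclose(mnrl, ce))  # True
```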
608
+
609
+ class Pooling(nn.Module):
610
+ """Performs pooling (max or mean) on the token embeddings.
611
+
612
+ Using pooling, it generates a fixed-size sentence embedding from a variable-length sentence. This layer also allows using the CLS token if it is returned by the underlying word embedding model.
613
+ You can concatenate multiple poolings together.
614
+ """
615
+ def __init__(self,
616
+ word_embedding_dimension: int,
617
+ pooling_mode_cls_token: bool = False,
618
+ pooling_mode_max_tokens: bool = False,
619
+ pooling_mode_mean_tokens: bool = True,
620
+ pooling_mode_mean_sqrt_len_tokens: bool = False,
621
+ ):
622
+ super(Pooling, self).__init__()
623
+
624
+ self.config_keys = ['word_embedding_dimension', 'pooling_mode_cls_token', 'pooling_mode_mean_tokens', 'pooling_mode_max_tokens', 'pooling_mode_mean_sqrt_len_tokens']
625
+
626
+ self.word_embedding_dimension = word_embedding_dimension
627
+ self.pooling_mode_cls_token = pooling_mode_cls_token
628
+ self.pooling_mode_mean_tokens = pooling_mode_mean_tokens
629
+ self.pooling_mode_max_tokens = pooling_mode_max_tokens
630
+ self.pooling_mode_mean_sqrt_len_tokens = pooling_mode_mean_sqrt_len_tokens
631
+
632
+ pooling_mode_multiplier = sum([pooling_mode_cls_token, pooling_mode_max_tokens, pooling_mode_mean_tokens, pooling_mode_mean_sqrt_len_tokens])
633
+ self.pooling_output_dimension = (pooling_mode_multiplier * word_embedding_dimension)
634
+
635
+ def forward(self, features: Dict[str, torch.Tensor]):
636
+ token_embeddings = features['token_embeddings']
637
+ cls_token = features['cls_token_embeddings']
638
+ attention_mask = features['attention_mask']
639
+
640
+ ## Pooling strategy
641
+ output_vectors = []
642
+ if self.pooling_mode_cls_token:
643
+ output_vectors.append(cls_token)
644
+ if self.pooling_mode_max_tokens:
645
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
646
+ token_embeddings[input_mask_expanded == 0] = -1e9 # Set padding tokens to large negative value
647
+ max_over_time = torch.max(token_embeddings, 1)[0]
648
+ output_vectors.append(max_over_time)
649
+ if self.pooling_mode_mean_tokens or self.pooling_mode_mean_sqrt_len_tokens:
650
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
651
+ sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
652
+
653
+ #If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present
654
+ if 'token_weights_sum' in features:
655
+ sum_mask = features['token_weights_sum'].unsqueeze(-1).expand(sum_embeddings.size())
656
+ else:
657
+ sum_mask = input_mask_expanded.sum(1)
658
+
659
+ sum_mask = torch.clamp(sum_mask, min=1e-9)
660
+
661
+ if self.pooling_mode_mean_tokens:
662
+ output_vectors.append(sum_embeddings / sum_mask)
663
+ if self.pooling_mode_mean_sqrt_len_tokens:
664
+ output_vectors.append(sum_embeddings / torch.sqrt(sum_mask))
665
+
666
+ output_vector = torch.cat(output_vectors, 1)
667
+ features.update({'sentence_embedding': output_vector})
668
+ return features
669
+
670
+ def get_sentence_embedding_dimension(self):
671
+ return self.pooling_output_dimension
672
+
673
+ def get_config_dict(self):
674
+ return {key: self.__dict__[key] for key in self.config_keys}
675
+
676
+ def save(self, output_path):
677
+ with open(os.path.join(output_path, 'config.json'), 'w') as fOut:
678
+ json.dump(self.get_config_dict(), fOut, indent=2)
679
+
680
+ @staticmethod
681
+ def load(input_path):
682
+ with open(os.path.join(input_path, 'config.json')) as fIn:
683
+ config = json.load(fIn)
684
+
685
+ return Pooling(**config)
686
+
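To make the masking concrete, here is a tiny sketch of mean pooling over hand-built tensors (shapes only, not tied to any checkpoint):

```python
import torch

pool = Pooling(word_embedding_dimension=4, pooling_mode_mean_tokens=True)

features = {
    'token_embeddings': torch.ones(1, 3, 4),      # batch=1, seq_len=3, dim=4
    'cls_token_embeddings': torch.ones(1, 4),
    'attention_mask': torch.tensor([[1, 1, 0]]),  # last position is padding
}
out = pool(features)
print(out['sentence_embedding'])                # mean over the 2 real tokens
print(pool.get_sentence_embedding_dimension())  # 4
```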
687
+
688
+ class RankingEvaluator:
689
+ def __init__(self, dataloader: DataLoader, random_paragraphs: DataLoader = None,
690
+ name: str = '', show_progress_bar: bool = None):
691
+ """
692
+ Constructs an evaluator for the dataset
693
+
694
+ The labels need to indicate the similarity between the sentences.
695
+
696
+ :param dataloader:
697
+ the data for the evaluation
698
+ :param random_paragraphs:
699
+ optional dataloader of extra paragraphs that are added to the index as distractors
700
+ """
701
+ self.dataloader = dataloader
702
+ self.name = name
703
+ if name:
704
+ name = "_" + name
705
+
706
+ if show_progress_bar is None:
707
+ show_progress_bar = (
708
+ logging.getLogger().getEffectiveLevel() == logging.INFO or logging.getLogger().getEffectiveLevel() == logging.DEBUG)
709
+ self.show_progress_bar = show_progress_bar
710
+
711
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
712
+
713
+ self.random_paragraphs = random_paragraphs
714
+
715
+ def __call__(self, model: 'SentenceTransformer', output_path: str = None, epoch: int = -1,
716
+ steps: int = -1) -> float:
717
+ model.eval()
718
+ embeddings1 = []
719
+ embeddings2 = []
720
+
721
+ if epoch != -1:
722
+ if steps == -1:
723
+ out_txt = f" after epoch {epoch}:"
724
+ else:
725
+ out_txt = f" in epoch {epoch} after {steps} steps:"
726
+ else:
727
+ out_txt = ":"
728
+
729
+ logging.info("Evaluation" + out_txt)
730
+
731
+ logging.info('Calculating sentence embeddings...')
732
+
733
+ self.dataloader.collate_fn = model.smart_batching_collate
734
+ iterator = self.dataloader
735
+ evidence_features = []
736
+ for step, batch in enumerate(tqdm(iterator, desc='batch')):
737
+ features, label_ids = batch_to_device(batch, self.device)
738
+ with torch.no_grad():
739
+ emb1, emb2 = [model(sent_features)['sentence_embedding'].to("cpu").numpy() for sent_features in
740
+ features]
741
+ embeddings1.extend(emb1)
742
+ embeddings2.extend(emb2)
743
+
744
+ input_ids_2 = features[1]['input_ids'].to("cpu").numpy()
745
+ for f in input_ids_2:
746
+ evidence_features.append(hashlib.sha1(str(f).encode('utf-8')).hexdigest())
747
+
748
+ total_examples = len(embeddings1)
749
+
750
+ logging.info('Building index...')
751
+
752
+ vindex = VectorIndex(len(embeddings1[0]))
753
+ for v in embeddings2:
754
+ vindex.add(v)
755
+
756
+ if self.random_paragraphs is not None:
757
+ logging.info('Calculating random wikipedia paragraph embeddings...')
758
+
759
+ self.random_paragraphs.collate_fn = model.smart_batching_collate
760
+ iterator = self.random_paragraphs
761
+ for step, batch in enumerate(iterator):
762
+ if step % 10 == 0:
763
+ logging.info('Batch {}/{}'.format(step, len(iterator)))
764
+ features, label_ids = batch_to_device(batch, self.device)
765
+ with torch.no_grad():
766
+ embeddings = model(features[0])['sentence_embedding'].to("cpu").numpy()
767
+ for emb in embeddings:
768
+ vindex.add(emb)
769
+
770
+ vindex.build()
771
+
772
+ logging.info('Ranking evaluation...')
773
+
774
+ mrr = 1e-8
775
+ recall = {1: 0, 5: 0, 10: 0, 20: 0, 100: 0}
776
+
777
+ all_results, _ = vindex.search(embeddings1, k=100, probes=1024)
778
+
779
+ for i in range(total_examples):
780
+ results = all_results[i]
781
+ rank = 1
782
+ found = False
783
+ for r in results:
784
+ if r < len(evidence_features) and evidence_features[r] == evidence_features[i]:
785
+ mrr += 1 / rank
786
+ found = True
787
+ break
788
+ rank += 1
789
+
790
+ for topk, count in recall.items():
791
+ if rank <= topk and found:
792
+ recall[topk] += 1
793
+ mrr /= total_examples
794
+ for topk, count in recall.items():
795
+ recall[topk] /= total_examples
796
+ logging.info('recall@{} : {}'.format(topk, recall[topk]))
797
+
798
+ logging.info('mrr@100 : {}'.format(mrr))
799
+
800
+ if output_path is not None:
801
+ f = open(output_path + '/stats.csv', 'a+')
802
+ f.write('{},{},{},{},{},{}\n'.format(recall[1], recall[5], recall[10], recall[20], recall[100], mrr))
803
+ f.close()
804
+
805
+ return mrr
806
+
807
+
808
+ if __name__ == "__main__":
809
+ logging.basicConfig(format='%(asctime)s - %(message)s',
810
+ datefmt='%Y-%m-%d %H:%M:%S',
811
+ level=logging.CRITICAL,
812
+ handlers=[LoggingHandler()])
813
+
814
+ # Read the dataset
815
+ model_name = 'distilroberta-base'
816
+ batch_size = 384
817
+ model_save_path = 'weights/encoder/qrbert-multitask-distil'
818
+ num_epochs = 1000
819
+
820
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
821
+
822
+ # Use BERT for mapping tokens to embeddings
823
+ word_embedding_model = BERT(model_name, max_seq_length=256, do_lower_case=False)
824
+
825
+ # Apply mean pooling to get one fixed sized sentence vector
826
+ pooling_model = Pooling(word_embedding_model.get_word_embedding_dimension(),
827
+ pooling_mode_mean_tokens=True,
828
+ pooling_mode_cls_token=False,
829
+ pooling_mode_max_tokens=False)
830
+
831
+ if torch.cuda.device_count() > 1:
832
+ print("Let's use", torch.cuda.device_count(), "GPUs!")
833
+
834
+ word_embedding_model = DataParallel(word_embedding_model)
835
+ word_embedding_model.to(device)
836
+
837
+ model = SentenceTransformer(modules=[word_embedding_model, pooling_model], parallel=True)
838
+
839
+ #model = SentenceTransformer('models/weights/encoder/sbert-nli-fnli-qar2-768', parallel=True)
840
+
841
+ logging.info("Reading dev dataset")
842
+ dev_data = SentencesDataset(get_retrieval_examples(filename='../data/retrieval/dev.jsonl',
843
+ no_statements=False,
844
+ ), model=model)
845
+ dev_dataloader = DataLoader(dev_data, shuffle=True, batch_size=1024)
846
+ evaluator = RankingEvaluator(dev_dataloader)
847
+ #model.best_score = model.evaluate(evaluator)
848
+
849
+ train_data = SentencesDataset(get_retrieval_examples(filename='../data/retrieval/train.jsonl',
850
+ no_statements=False),
851
+ model=model)
852
+ train_dataloader = DataLoader(train_data, shuffle=True, batch_size=batch_size)
853
+ train_loss = MultipleNegativesRankingLoss(model=model)
854
+
855
+ for i in range(num_epochs):
856
+ logging.info("Epoch {}".format(i))
857
+ # Train the model
858
+ model.fit(train_objectives=[(train_dataloader, train_loss)],
859
+ evaluator=evaluator,
860
+ epochs=1,
861
+ acc_steps=1,
862
+ evaluation_steps=1000,
863
+ warmup_steps=1000,
864
+ output_path=model_save_path,
865
+ optimizer_params={'lr': 1e-6, 'eps': 1e-6, 'correct_bias': False}
866
+ )
867
+ torch.cuda.empty_cache()
models/tokenization.py ADDED
@@ -0,0 +1,56 @@
1
+ import re  # needed: the text cleanup in tokenize() below is a regex, not a literal string
+ import numpy as np
2
+
3
+
4
+ def get_segments(sentence):
5
+ sentence_segments = []
6
+ temp = []
7
+ i = 0
8
+ for token in sentence.split(" "):
9
+ temp.append(i)
10
+ if token == "[SEP]":
11
+ i += 1
12
+ sentence_segments.append(temp)
13
+ return sentence_segments
14
+
15
+
16
+ def tokenize(text, max_length, tokenizer, second_text=None):
17
+ if second_text is None:
18
+ sentence = "[CLS] " + " ".join(tokenizer.tokenize(re.sub(r'[^\w\s]+|\n', '', text))[:max_length-2]) + " [SEP]"  # str.replace treated the pattern as a literal (a no-op); re.sub applies it as intended
19
+ else:
20
+ text = tokenizer.tokenize(re.sub(r'[^\w\s]+|\n', '', text))
21
+ second_text = tokenizer.tokenize(re.sub(r'[^\w\s]+|\n', '', second_text))
22
+ while len(text) + len(second_text) > max_length - 3:
23
+ if len(text) > len(second_text):
24
+ text.pop()
25
+ else:
26
+ second_text.pop()
27
+ sentence = "[CLS] " + " ".join(text) + " [SEP] " + " ".join(second_text) + " [SEP]"
28
+
29
+ # generate masks
30
+ # bert requires a mask for the words which are padded.
31
+ # Say for example, maxlen is 100, sentence size is 90. then, [1]*90 + [0]*[100-90]
32
+ sentence_mask = [1] * len(sentence.split(" ")) + [0] * (max_length - len(sentence.split(" ")))
33
+
34
+ # generate input ids
35
+ # if less than max length provided then the words are padded
36
+ if len(sentence.split(" ")) != max_length:
37
+ sentence_padded = sentence + " [PAD]" * (max_length - len(sentence.split(" ")))
38
+ else:
39
+ sentence_padded = sentence
40
+
41
+ sentence_converted = tokenizer.convert_tokens_to_ids(sentence_padded.split(" "))
42
+
43
+ # generate segments
44
+ # for each separation [SEP], a new segment is converted
45
+ sentence_segment = get_segments(sentence_padded)
46
+
47
+ # convert list into tensor integer arrays and return it
48
+ # return sentences_converted,sentences_segment, sentences_mask
49
+ """
50
+ return [tf.cast(sentence_converted, tf.int32),
51
+ tf.cast(sentence_segment, tf.int32),
52
+ tf.cast(sentence_mask, tf.int32)]
53
+ """
54
+ return [np.asarray(sentence_converted, dtype=np.int32).squeeze(),
55
+ np.asarray(sentence_segment, dtype=np.int32).squeeze(),
56
+ np.asarray(sentence_mask, dtype=np.int32).squeeze()]
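A hedged usage sketch for the helper above, with a standard `transformers` tokenizer (the checkpoint name is illustrative; any BERT-style vocab with [CLS]/[SEP]/[PAD] tokens should work):

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

ids, segments, mask = tokenize("Is coffee good for you?", max_length=32,
                               tokenizer=tokenizer,
                               second_text="Coffee has several health benefits.")
print(ids.shape, segments.shape, mask.shape)  # (32,) (32,) (32,) after padding
```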
models/vector_index.py ADDED
@@ -0,0 +1,102 @@
1
+ import logging
2
+ import math
3
+ import pickle
4
+ import h5py
5
+ import faiss
6
+
7
+ import numpy as np
8
+ import gpustat
9
+
10
+ import pdb
11
+
12
+ # see http://ulrichpaquet.com/Papers/SpeedUp.pdf theorem 5
13
+
14
+ def get_phi(xb):
15
+ return (xb ** 2).sum(1).max()
16
+
17
+
18
+ def augment_xb(xb, phi=None):
19
+ norms = (xb ** 2).sum(1)
20
+ if phi is None:
21
+ phi = norms.max()
22
+ extracol = np.sqrt(phi - norms)
23
+ return np.hstack((xb, extracol.reshape(-1, 1)))
24
+
25
+
26
+ def augment_xq(xq):
27
+ extracol = np.zeros(len(xq), dtype='float32')
28
+ return np.hstack((xq, extracol.reshape(-1, 1)))
29
+
30
+
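The two augmentation helpers implement the reduction from the paper cited above: appending sqrt(phi - ||x||^2) to every database vector and a 0 to every query gives ||q' - x'||^2 = ||q||^2 + phi - 2<q, x>, so L2 nearest-neighbour search in the augmented space ranks exactly like maximum-inner-product search. A quick numeric check (random data, illustrative only; note that the call to augment_xq in build() below is currently commented out):

```python
import numpy as np

rng = np.random.default_rng(0)
xb = rng.normal(size=(5, 3)).astype('float32')  # database vectors
xq = rng.normal(size=(1, 3)).astype('float32')  # one query

xb_aug = augment_xb(xb)
xq_aug = augment_xq(xq)

l2 = ((xq_aug - xb_aug) ** 2).sum(1)  # distances in the augmented space
ip = (xq * xb).sum(1)                 # inner products in the original space
print(np.argmin(l2) == np.argmax(ip))  # True: L2-nearest == max inner product
```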
31
+ class VectorIndex:
32
+ def __init__(self, d):
33
+ self.d = d
34
+ self.vectors = []
35
+ self.index = None
36
+
37
+ def add(self, v):
38
+ self.vectors.append(v)
39
+
40
+ def add_vectors(self, vs):
41
+ logging.info('Adding vectors to index...')
42
+ self.index.add(vs)
43
+
44
+
45
+ def build(self, use_gpu=False):  # note: constructs (and trains) the index but does not insert the vectors; call add_vectors() afterwards
46
+ self.vectors = np.array(self.vectors) # OOM at this step if building too many vectors
47
+
48
+ faiss.normalize_L2(self.vectors)
49
+
50
+ #self.vectors = augment_xq(self.vectors)
51
+
52
+ logging.info('Indexing {} vectors'.format(self.vectors.shape[0]))
53
+
54
+ if self.vectors.shape[0] > 50000:
55
+ num_centroids = 8 * int(math.sqrt(math.pow(2, int(math.log(self.vectors.shape[0], 2)))))
56
+
57
+ logging.info('Using {} centroids'.format(num_centroids))
58
+
59
+ self.index = faiss.index_factory(self.d, "IVF{}_HNSW32,Flat".format(num_centroids))
60
+
61
+ ngpu = faiss.get_num_gpus()
62
+ if ngpu > 0 and use_gpu:
63
+ logging.info('Using {} GPUs'.format(ngpu))
64
+
65
+ index_ivf = faiss.extract_index_ivf(self.index)
66
+ gpustat.print_gpustat()
67
+ clustering_index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(self.d))
68
+ index_ivf.clustering_index = clustering_index
69
+
70
+ logging.info('Training index...')
71
+
72
+ self.index.train(self.vectors)
73
+ else:
74
+ self.index = faiss.IndexFlatL2(self.d)
75
+ if faiss.get_num_gpus() > 0 and use_gpu:
76
+ gpustat.print_gpustat()
77
+ self.index = faiss.index_cpu_to_all_gpus(self.index)
78
+
79
+ def load(self, path):
80
+ self.index = faiss.read_index(path)
81
+
82
+ def save(self, path):
83
+ gpustat.print_gpustat()
84
+ faiss.write_index(faiss.index_gpu_to_cpu(self.index), path)
85
+
86
+ def save_vectors(self, path):
87
+ #pickle.dump(self.vectors, open(path, 'wb'), protocol=4)
88
+ f = h5py.File(path, 'w')
89
+ dset = f.create_dataset('data', data=self.vectors)
90
+ f.close()
91
+
92
+ def search(self, vectors, k=1, probes=1):
93
+ if not isinstance(vectors, np.ndarray):
94
+ vectors = np.array(vectors)
95
+ #faiss.normalize_L2(vectors)
96
+ try:
97
+ self.index.nprobe = probes
98
+ except Exception:  # flat (non-IVF) indexes expose no nprobe parameter
99
+ pass
100
+ distances, ids = self.index.search(vectors, k)
101
+ similarities = [(2 - d) / 2 for d in distances]  # for L2-normalized vectors, ||a-b||^2 = 2 - 2*cos(a,b), so cos = (2 - d)/2
102
+ return ids, similarities
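An end-to-end sketch of the index on a small random collection (fewer than 50000 vectors, so the exact IndexFlatL2 branch is taken; GPU off). Since `build()` only constructs the index, the vectors still have to be inserted with `add_vectors()`:

```python
import numpy as np

index = VectorIndex(d=4)
rng = np.random.default_rng(0)
for _ in range(100):
    index.add(rng.normal(size=4).astype('float32'))

index.build(use_gpu=False)        # small collection -> exact IndexFlatL2
index.add_vectors(index.vectors)  # insert the (now L2-normalized) vectors

query = index.vectors[:1].copy()  # use a stored vector as its own query
ids, sims = index.search(query, k=3)
print(ids[0], sims[0])  # the top hit is the query itself with score ~1.0
```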
requirements.txt ADDED
@@ -0,0 +1,30 @@
1
+ tqdm==4.41.0
2
+ transformers==3.0.2
3
+ faiss_gpu==1.6.3
4
+ syllables==0.1.0
5
+ scipy==1.3.1
6
+ nltk==3.2.5
7
+ lxml==4.6.3
8
+ requests==2.22.0
9
+ Flask_Cors==3.0.8
10
+ gpustat==0.6.0
11
+ h5py==2.10.0
12
+ packaging==20.3
13
+ pandas==1.0.3
14
+ spacy==3.0.3
15
+ sentence-transformers==0.3.0
16
+ numpy==1.17.4
17
+ matplotlib==3.3.1
18
+ filelock==3.0.12
19
+ boilerpipe3==1.1
20
+ Unidecode==1.1.1
21
+ Flask==2.0.0
22
+ apex
23
+ beautifulsoup4==4.9.3
24
+ python_dateutil==2.8.1
25
+ scikit_learn==0.24.1
26
+ tantivy==0.13.2
27
+ huggingface_hub==0.16.4
28
+ torch==1.6.0+cu101
29
+ torchvision==0.7.0+cu101
30
+ gradio==3.34.0
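Two of these pins need extra care at install time: the `+cu101` torch wheels are not hosted on PyPI, and the package published on PyPI under the name `apex` is not NVIDIA's apex (which is built from source). A hedged install sketch:

```shell
# the +cuXXX wheels come from the PyTorch wheel index, not PyPI
pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 \
    -f https://download.pytorch.org/whl/torch_stable.html

# NVIDIA apex is built from source; a bare `pip install apex` fetches an unrelated package
git clone https://github.com/NVIDIA/apex
cd apex && pip install -v --no-cache-dir ./
```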
utils.py ADDED
@@ -0,0 +1,71 @@
1
+ from html.parser import HTMLParser
2
+
3
+
4
+ class WebParser(HTMLParser):
5
+ """
6
+ A class for converting tagged HTML into formats that can be used by an ML model
7
+ """
8
+ def __init__(self):
9
+ super().__init__()
10
+ self.block_tags = {
11
+ 'div', 'p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6'
12
+ }
13
+ self.inline_tags = {
14
+ '', 'a', 'b', 'main', 'span', 'em', 'strong', 'br'
15
+ }
16
+ self.allowed_tags = self.block_tags.union(self.inline_tags)  # was a hand-written set that omitted h1-h6, silently dropping heading text
17
+
18
+ self.opened_tags = []
19
+ self.block_content = ''
20
+ self.blocks = []
21
+
22
+ def get_last_opened_tag(self):
23
+ """
24
+ Gets the last visited tag
25
+ :return:
26
+ """
27
+ if len(self.opened_tags) > 0:
28
+ return self.opened_tags[len(self.opened_tags) - 1]
29
+ return ''
30
+
31
+ def error(self, message):
32
+ pass
33
+
34
+ def handle_starttag(self, tag, attrs):
35
+ """
36
+ Handles the start tag of an HTML node in the tree
37
+ :param tag: the HTML tag
38
+ :param attrs: the tag attributes
39
+ :return:
40
+ """
41
+ self.opened_tags.append(tag)
42
+
43
+ def handle_endtag(self, tag):
44
+ """
45
+ Handles the end tag of an HTML node in the tree
46
+ :param tag: the HTML tag
47
+ :return:
48
+ """
49
+ if tag in self.block_tags:
50
+ self.block_content = self.block_content.strip()
51
+ if len(self.block_content) > 0:
52
+ #if not self.block_content.endswith('.'): self.block_content += '.'
53
+ self.blocks.append(self.block_content)
54
+ self.block_content = ''
55
+ if len(self.opened_tags) > 0:
56
+ self.opened_tags.pop()
57
+
58
+ def handle_data(self, data):
59
+ """
60
+ Handles a text HTML node in the tree
61
+ :param data: the text node
62
+ :return:
63
+ """
64
+ last_opened_tag = self.get_last_opened_tag()
65
+ if last_opened_tag in self.allowed_tags:
66
+ data = data.replace(' ', ' ').strip()
67
+ if data != '':
68
+ self.block_content += data + ' '
69
+
70
+ def get_blocks(self):
71
+ return self.blocks
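A quick sketch of running HTML through this parser (the snippet is made up; real pages will of course yield more blocks):

```python
parser = WebParser()
parser.feed("<div><p>Coffee is a brewed drink. "
            "It is <b>popular</b> worldwide.</p></div>")
print(parser.get_blocks())
# ['Coffee is a brewed drink. It is popular worldwide.']
```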
web_search.py ADDED
@@ -0,0 +1,140 @@
1
+ import re
2
+
3
+ from pprint import pprint
4
+ from urllib.request import Request, urlopen
5
+ from html.parser import HTMLParser
6
+ from bs4 import BeautifulSoup
7
+ from urllib.parse import quote_plus
8
+
9
+
10
+ def bing_search(query: str, pages_number=1) -> list:
11
+ """
12
+ Gets web results from Bing
13
+ :param query: query to search
14
+ :param pages_number: number of search pages to scrape
15
+ :return: a list of links in ranked order
16
+ """
17
+ urls = []
18
+ for page in range(pages_number):
19
+ first = page * 10 + 1
20
+ address = "https://www.bing.com/search?q=" + quote_plus(query) + '&first=' + str(first)
21
+ data = get_html(address)
22
+ soup = BeautifulSoup(data, 'lxml')
23
+ links = soup.findAll('li', {'class': 'b_algo'})
24
+ urls.extend([link.find('h2').find('a')['href'] for link in links])
25
+
26
+ return urls
27
+
28
+
29
+ def duckduckgo_search(query: str, pages=1):
30
+ urls = []
31
+ start_index = 0
32
+ for page in range(pages):
33
+ address = "https://duckduckgo.com/html/?kl=en-us&q={}&s={}".format(quote_plus(query), start_index)
34
+ data = get_html(address)
35
+ soup = BeautifulSoup(data, 'lxml')
36
+ links = soup.findAll('a', {'class': 'result__a'})
37
+ urls.extend([link['href'] for link in links])
38
+ start_index = len(urls)
39
+
40
+ return urls
41
+
42
+
43
+ def news_search(query: str, pages=1):
44
+ urls = []
45
+ for page in range(pages):
46
+ api_url = f'https://newslookup.com/results?l=2&q={quote_plus(query)}&dp=&mt=-1&mkt=0&mtx=0&mktx=0&s=&groupby=no&cat=-1&from=&fmt=&tp=720&ps=50&ovs=&page={page}'
47
+ data = get_html(api_url)
48
+ soup = BeautifulSoup(data, 'lxml')
49
+ links = soup.findAll('a', {'class': 'title'})
50
+ urls.extend([link['href'] for link in links])
51
+ return urls
52
+
53
+
54
+ def get_html(url: str) -> str:
55
+ """
56
+ Downloads the html source code of a webpage
57
+ :param url:
58
+ :return: html source code
59
+ """
60
+ try:
61
+ custom_user_agent = "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:47.0) Gecko/20100101 Firefox/47.0"
62
+ req = Request(url, headers={"User-Agent": custom_user_agent})
63
+ page = urlopen(req, timeout=3)
64
+ return str(page.read())  # str() over bytes yields a "b'...'" literal; WebParser below strips the escaped \n / \r sequences
65
+ except Exception:  # treat timeouts and HTTP errors as an empty page
66
+ return ''
67
+
68
+
69
+ class WebParser(HTMLParser):
70
+ """
71
+ A class for converting the tagged html to formats that can be used by a ML model
72
+ """
73
+
74
+ def __init__(self):
75
+ super().__init__()
76
+ self.block_tags = {
77
+ 'div', 'p'
78
+ }
79
+ self.inline_tags = {
80
+ '', 'a', 'b', 'tr', 'main', 'span', 'time', 'td',
81
+ 'sup', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'em', 'strong', 'br'
82
+ }
83
+ self.allowed_tags = self.block_tags.union(self.inline_tags)
84
+ self.opened_tags = []
85
+ self.block_content = ''
86
+ self.blocks = []
87
+
88
+ def get_last_opened_tag(self):
89
+ """
90
+ Gets the last visited tag
91
+ :return:
92
+ """
93
+ if len(self.opened_tags) > 0:
94
+ return self.opened_tags[len(self.opened_tags) - 1]
95
+ return ''
96
+
97
+ def error(self, message):
98
+ pass
99
+
100
+ def handle_starttag(self, tag, attrs):
101
+ """
102
+ Handles the start tag of an HTML node in the tree
103
+ :param tag: the HTML tag
104
+ :param attrs: the tag attributes
105
+ :return:
106
+ """
107
+ self.opened_tags.append(tag)
108
+ if tag in self.block_tags:
109
+ self.block_content = self.block_content.strip()
110
+ if len(self.block_content) > 0:
111
+ if not self.block_content.endswith('.'):
112
+ self.block_content += '.'
113
+ self.block_content = self.block_content.replace('\\n', ' ').replace('\\r', ' ')
114
+ self.block_content = re.sub(r"\s\s+", " ", self.block_content)  # raw string avoids the invalid-escape warning
115
+ self.blocks.append(self.block_content)
116
+ self.block_content = ''
117
+
118
+ def handle_endtag(self, tag):
119
+ """
120
+ Handles the end tag of an HTML node in the tree
121
+ :param tag: the HTML tag
122
+ :return:
123
+ """
124
+ if len(self.opened_tags) > 0:
125
+ self.opened_tags.pop()
126
+
127
+ def handle_data(self, data):
128
+ """
129
+ Handles a text HTML node in the tree
130
+ :param data: the text node
131
+ :return:
132
+ """
133
+ last_opened_tag = self.get_last_opened_tag()
134
+ if last_opened_tag in self.allowed_tags:
135
+ data = data.replace(' ', ' ').strip()
136
+ if data != '':
137
+ self.block_content += data + ' '
138
+
139
+ def get_text(self):
140
+ return "\n\n".join(self.blocks)
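Putting this file together, a hedged sketch of the fetch-and-extract flow (live results vary, and DuckDuckGo's HTML endpoint may throttle or change, so treat this as illustrative):

```python
urls = duckduckgo_search("health benefits of coffee", pages=1)

for url in urls[:3]:
    html = get_html(url)
    if not html:
        continue  # request failed or timed out; get_html returns ''
    parser = WebParser()
    parser.feed(html)
    print(url)
    print(parser.get_text()[:200], "...")
```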