anabmaulana committed on
Commit 071a1df · 1 Parent(s): 8eb533e
Files changed (3):
  1. models/vector_index.py +0 -1
  2. quin_web.py +278 -0
  3. requirements.txt +0 -1
models/vector_index.py CHANGED
@@ -1,7 +1,6 @@
  import logging
  import math
  import pickle
- import h5py
  import faiss

  import numpy as np
quin_web.py ADDED
@@ -0,0 +1,278 @@
+ import argparse
+ import gradio as gr
+ import json
+ import logging
+ import nltk
+ import numpy
+ import os
+ import re
+ import requests
+ import torch
+ import urllib
+ import urllib.parse as urlparse
+
+ from bs4 import BeautifulSoup
+ from flask import request, Flask, jsonify
+ from flask_cors import CORS
+ from huggingface_hub import hf_hub_download, snapshot_download, login
+ from models.nli import NLI
+ from models.qa_ranker import PassageRanker
+ from models.text_encoder import SentenceTransformer
+ from multiprocessing.pool import ThreadPool
+ from scipy import spatial
+ from scipy.special import softmax
+ from time import sleep
+ from urllib.parse import parse_qs
+ from web_search import get_html, WebParser, duckduckgo_search
+
+
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+ API_URLS = ['https://letscheck.nus.edu.sg/quin/']
+ logging.getLogger().setLevel(logging.INFO)
+
+
+ def is_question(query):
+     if re.match(r'^(who|when|what|why|which|whose|is|are|was|were|do|does|did|how)', query) or query.endswith('?'):
+         return True
+     pos_tags = nltk.pos_tag(nltk.word_tokenize(query))
+     for tag in pos_tags:
+         if tag[1].startswith('VB'):
+             return False
+     return True
+
+
+ def clear_cache():
+     for url in API_URLS:
+         print(f"Clearing cache from {url}")
+         r = requests.post(f'{url}/clear_cache?secret=123')
+         print(r.status_code)
+         print(r)
+
+
+ class Quin:
+     def __init__(self):
+         nltk.download('punkt')
+         self.sent_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
+         self.sent_tokenizer._params.abbrev_types.update(['e.g', 'i.e', 'subsp'])
+         self.app = Flask(__name__)
+         CORS(self.app)
+
+         snapshot_download(repo_id="anabmaulana/qrbert-multitask", local_dir="cache/qrbert-multitask")
+         hf_hub_download(repo_id="anabmaulana/quin-passage-ranker", filename="passage_ranker.state_dict", local_dir="cache/quin-passage-ranker")
+         hf_hub_download(repo_id="anabmaulana/quin-nli", filename="nli.state_dict", local_dir="cache/quin-nli")
+
+         # torch.cuda.is_available = lambda: False
+         device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+         self.text_embedding_model = SentenceTransformer('cache/qrbert-multitask',
+                                                         device=device,
+                                                         parallel=True)
+         self.passage_ranking_model = PassageRanker(
+             model_path='cache/quin-passage-ranker/passage_ranker.state_dict',
+             gpu=torch.cuda.is_available(),
+             batch_size=16)
+         self.passage_ranking_model.eval()
+         self.nli_model = NLI('cache/quin-nli/nli.state_dict',
+                              batch_size=64,
+                              device=device,
+                              parallel=True)
+         self.nli_model.eval()
+         logging.info('Initialized!')
+
+     def extract_snippets(self, text, sentences_per_snippet=4):
+         sentences = self.sent_tokenizer.tokenize(text)
+         snippets = []
+         i = 0
+         last_index = 0
+         while i < len(sentences):
+             snippet = ' '.join(sentences[i:i + sentences_per_snippet])
+             if len(snippet.split(' ')) > 4:
+                 snippets.append(snippet)
+                 last_index = i + sentences_per_snippet
+             i += sentences_per_snippet
+         if last_index < len(sentences):
+             snippet = ' '.join(sentences[last_index:])
+             if len(snippet.split(' ')) > 4:
+                 snippets.append(snippet)
+         return snippets
+
+     def search_web_evidence(self, query, limit=10):
+         logging.info('searching the web...')
+         urls = duckduckgo_search(query, pages=1)[:limit]
+         logging.info('downloading {} web pages...'.format(len(urls)))
+         search_results = []
+         # Move this out of the function, since it is used a few times
+         query_is_question = is_question(query)
+
+         # Local methods for multithreading
+         def download(url):
+             nonlocal search_results
+             data = get_html(url)
+             soup = BeautifulSoup(data, features='lxml')
+             title = soup.title.string
+             w = WebParser()
+             w.feed(data)
+             new_snippets = sum([self.extract_snippets(b) for b in w.blocks if b.count(' ') > 20], [])
+             new_snippets = [{'snippet': snippet, 'url': url, 'title': title} for snippet in new_snippets]
+             search_results += new_snippets
+
+         def timeout_download(arg):
+             pool = ThreadPool(1)
+             try:
+                 pool.apply_async(download, [arg]).get(timeout=2)
+             except:
+                 pass
+             pool.close()
+             pool.join()
+
+         p = ThreadPool(32)
+         p.map(timeout_download, urls)
+         p.close()
+         p.join()
+
+         if query_is_question:
+             logging.info('re-ranking...')
+             snippets = [s['snippet'] for s in search_results]
+             qa_pairs = [(query, snippet) for snippet in snippets]
+             _, probs = self.passage_ranking_model(qa_pairs)
+             probs = [softmax(p)[1] for p in probs]
+             filtered_results = []
+             for i in range(len(search_results)):
+                 if probs[i] > 0.35:
+                     search_results[i]['score'] = str(probs[i])
+                     filtered_results.append(search_results[i])
+             search_results = filtered_results
+             search_results = sorted(search_results, key=lambda x: float(x['score']), reverse=True)
+
+         # Split into a different function below
+         # highlight_results(self, search_results)
+         # highlight most relevant sentences
+         logging.info('highlighting...')
+         results_sentences = []
+         sentences_texts = []
+         sentences_vectors = {}
+         for i, r in enumerate(search_results):
+             sentences = self.sent_tokenizer.tokenize(r['snippet'])
+             sentences = [s for s in sentences if len(s.split(' ')) > 4]
+             sentences_texts.extend(sentences)
+             results_sentences.append(sentences)
+
+         vectors = self.text_embedding_model.encode(sentences=sentences_texts, batch_size=128)
+         for i, v in enumerate(vectors):
+             sentences_vectors[sentences_texts[i]] = v
+
+         query_vector = self.text_embedding_model.encode(sentences=[query], batch_size=1)[0]
+         for i, sentences in enumerate(results_sentences):
+             best_sentences = set()
+             evidence_sentences = []
+             for sentence in sentences:
+                 sentence_vector = sentences_vectors[sentence]
+                 score = 1 - spatial.distance.cosine(query_vector, sentence_vector)
+                 if score > 0.91:
+                     best_sentences.add(sentence)
+                     evidence_sentences.append(sentence)
+             if len(evidence_sentences) > 0:
+                 search_results[i]['evidence'] = ' '.join(evidence_sentences)
+                 search_results[i]['snippet'] = \
+                     ' '.join([s if s not in best_sentences else '<b>{}</b>'.format(s) for s in sentences])
+
+         search_results = [s for s in search_results if 'evidence' in s]
+
+         # Split into a different function below
+         # entailment_classification(self, search_results)
+         if not query_is_question:
+             logging.info('entailment classification...')
+             es_pairs = []
+             for result in search_results:
+                 evidence = result['evidence']
+                 es_pairs.append((evidence, query))
+             try:
+                 labels, probs = self.nli_model(es_pairs)
+                 logging.info(str(labels))
+                 for i in range(len(labels)):
+                     confidence = numpy.exp(numpy.max(probs[i]))
+                     if confidence > 0.4:
+                         search_results[i]['nli_class'] = labels[i]
+                     else:
+                         search_results[i]['nli_class'] = 'neutral'
+                     search_results[i]['nli_confidence'] = str(confidence)
+             except Exception as e:
+                 logging.warning('Error doing entailment classification')
+                 logging.warning(str(e))
+             try:
+                 search_results = [s for s in search_results if s['nli_class'] != 'neutral']
+             except:
+                 pass
+
+         filtered_results = []
+         added_urls = set()
+         supporting = 0
+         refuting = 0
+         for r in search_results:
+             if r['url'] not in added_urls:
+                 filtered_results.append(r)
+                 added_urls.add(r['url'])
+                 if 'nli_class' in r:
+                     if r['nli_class'] == 'entailment':
+                         supporting += 1
+                     elif r['nli_class'] == 'contradiction':
+                         refuting += 1
+         search_results = filtered_results
+
+         if supporting > 4 or refuting > 4:
+             if supporting / (refuting + 0.001) > 1.7:
+                 veracity = 'Probably True'
+             elif refuting / (supporting + 0.001) > 1.7:
+                 veracity = 'Probably False'
+             else:
+                 veracity = '? Ambiguous'
+         else:
+             veracity = 'Not enough evidence'
+
+         search_results = search_results[:limit]
+
+         logging.info('done searching')
+
+         if query_is_question:
+             return {'type': 'question',
+                     'results': search_results}
+         else:
+             return {'type': 'statement',
+                     'supporting': supporting,
+                     'refuting': refuting,
+                     'veracity_rating': veracity,
+                     'results': search_results}
+
+     def build_endpoints(self):
+         @self.app.route('/search', methods=['POST', 'GET'])
+         def search_endpoint():
+             query = request.args.get('query').lower()
+             limit = request.args.get('limit') or 100
+             limit = min(int(limit), 100)
+             results = json.dumps(self.search_web_evidence(query, limit=limit), indent=4)
+             return results
+
+     def serve(self, port=12345):
+         self.build_endpoints()
+         self.app.run(host='0.0.0.0', port=port)
+
+
+
+ fchecker = Quin()
+
+ def predict(query, limit):
+     return fchecker.search_web_evidence(query, limit=limit)
+
+ demo = gr.Interface(
+     fn=predict,
+     inputs=[
+         gr.Textbox(label="Enter Query"),  # string input
+         gr.Number(label="Enter Limit")  # number input
+     ],
+     outputs=gr.JSON(label="Output")
+ )
+
+
+ if __name__ == '__main__':
+     # demo.launch()
+     fchecker.search_web_evidence('Barrack Obama lived in Indonesia', limit=5)
requirements.txt CHANGED
@@ -8,7 +8,6 @@ lxml==4.6.3
  requests==2.22.0
  Flask_Cors==3.0.8
  gpustat==0.6.0
- h5py==2.10.0
  packaging==20.9
  pandas==1.3.4
  spacy==3.2.0