Commit 071a1df
Parent(s): 8eb533e
Update
Files changed:
- models/vector_index.py +0 -1
- quin_web.py +278 -0
- requirements.txt +0 -1
models/vector_index.py
CHANGED
@@ -1,7 +1,6 @@
 import logging
 import math
 import pickle
-import h5py
 import faiss
 
 import numpy as np
quin_web.py
ADDED
@@ -0,0 +1,278 @@
+import argparse
+import gradio as gr
+import json
+import logging
+import nltk
+import numpy
+import os
+import re
+import requests
+import torch
+import urllib
+import urllib.parse as urlparse
+
+from bs4 import BeautifulSoup
+from flask import request, Flask, jsonify
+from flask_cors import CORS
+from huggingface_hub import hf_hub_download, snapshot_download, login
+from models.nli import NLI
+from models.qa_ranker import PassageRanker
+from models.text_encoder import SentenceTransformer
+from multiprocessing.pool import ThreadPool
+from scipy import spatial
+from scipy.special import softmax
+from time import sleep
+from urllib.parse import parse_qs
+from web_search import get_html, WebParser, duckduckgo_search
+
+
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+API_URLS = ['https://letscheck.nus.edu.sg/quin/']
+logging.getLogger().setLevel(logging.INFO)
+
+
+def is_question(query):
+    if re.match(r'^(who|when|what|why|which|whose|is|are|was|were|do|does|did|how)', query) or query.endswith('?'):
+        return True
+    pos_tags = nltk.pos_tag(nltk.word_tokenize(query))
+    for tag in pos_tags:
+        if tag[1].startswith('VB'):
+            return False
+    return True
+
+
+def clear_cache():
+    for url in API_URLS:
+        print(f"Clearing cache from {url}")
+        r = requests.post(f'{url}/clear_cache?secret=123')
+        print(r.status_code)
+        print(r)
+
+
+class Quin:
+    def __init__(self):
+        nltk.download('punkt')
+        self.sent_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
+        self.sent_tokenizer._params.abbrev_types.update(['e.g', 'i.e', 'subsp'])
+        self.app = Flask(__name__)
+        CORS(self.app)
+
+        snapshot_download(repo_id="anabmaulana/qrbert-multitask", local_dir="cache/qrbert-multitask")
+        hf_hub_download(repo_id="anabmaulana/quin-passage-ranker", filename="passage_ranker.state_dict", local_dir="cache/quin-passage-ranker")
+        hf_hub_download(repo_id="anabmaulana/quin-nli", filename="nli.state_dict", local_dir="cache/quin-nli")
+
+        # torch.cuda.is_available = lambda: False
+        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+        self.text_embedding_model = SentenceTransformer('cache/qrbert-multitask',
+                                                        device=device,
+                                                        parallel=True)
+        self.passage_ranking_model = PassageRanker(
+            model_path='cache/quin-passage-ranker/passage_ranker.state_dict',
+            gpu=torch.cuda.is_available(),
+            batch_size=16)
+        self.passage_ranking_model.eval()
+        self.nli_model = NLI('cache/quin-nli/nli.state_dict',
+                             batch_size=64,
+                             device=device,
+                             parallel=True)
+        self.nli_model.eval()
+        logging.info('Initialized!')
+
+    def extract_snippets(self, text, sentences_per_snippet=4):
+        sentences = self.sent_tokenizer.tokenize(text)
+        snippets = []
+        i = 0
+        last_index = 0
+        while i < len(sentences):
+            snippet = ' '.join(sentences[i:i + sentences_per_snippet])
+            if len(snippet.split(' ')) > 4:
+                snippets.append(snippet)
+                last_index = i + sentences_per_snippet
+            i += sentences_per_snippet
+        if last_index < len(sentences):
+            snippet = ' '.join(sentences[last_index:])
+            if len(snippet.split(' ')) > 4:
+                snippets.append(snippet)
+        return snippets
+
+    def search_web_evidence(self, query, limit=10):
+        logging.info('searching the web...')
+        urls = duckduckgo_search(query, pages=1)[:limit]
+        logging.info('downloading {} web pages...'.format(len(urls)))
+        search_results = []
+        # Move this out of the function, since it is used a few times
+        query_is_question = is_question(query)
+
+        # Local methods for multithreading
+        def download(url):
+            nonlocal search_results
+            data = get_html(url)
+            soup = BeautifulSoup(data, features='lxml')
+            title = soup.title.string
+            w = WebParser()
+            w.feed(data)
+            new_snippets = sum([self.extract_snippets(b) for b in w.blocks if b.count(' ') > 20], [])
+            new_snippets = [{'snippet': snippet, 'url': url, 'title': title} for snippet in new_snippets]
+            search_results += new_snippets
+
+        def timeout_download(arg):
+            pool = ThreadPool(1)
+            try:
+                pool.apply_async(download, [arg]).get(timeout=2)
+            except:
+                pass
+            pool.close()
+            pool.join()
+
+        p = ThreadPool(32)
+        p.map(timeout_download, urls)
+        p.close()
+        p.join()
+
+        if query_is_question:
+            logging.info('re-ranking...')
+            snippets = [s['snippet'] for s in search_results]
+            qa_pairs = [(query, snippet) for snippet in snippets]
+            _, probs = self.passage_ranking_model(qa_pairs)
+            probs = [softmax(p)[1] for p in probs]
+            filtered_results = []
+            for i in range(len(search_results)):
+                if probs[i] > 0.35:
+                    search_results[i]['score'] = str(probs[i])
+                    filtered_results.append(search_results[i])
+            search_results = filtered_results
+            search_results = sorted(search_results, key=lambda x: float(x['score']), reverse=True)
+
+        # Split into a different function below
+        # highlight_results(self, search_results)
+        # highlight most relevant sentences
+        logging.info('highlighting...')
+        results_sentences = []
+        sentences_texts = []
+        sentences_vectors = {}
+        for i, r in enumerate(search_results):
+            sentences = self.sent_tokenizer.tokenize(r['snippet'])
+            sentences = [s for s in sentences if len(s.split(' ')) > 4]
+            sentences_texts.extend(sentences)
+            results_sentences.append(sentences)
+
+        vectors = self.text_embedding_model.encode(sentences=sentences_texts, batch_size=128)
+        for i, v in enumerate(vectors):
+            sentences_vectors[sentences_texts[i]] = v
+
+        query_vector = self.text_embedding_model.encode(sentences=[query], batch_size=1)[0]
+        for i, sentences in enumerate(results_sentences):
+            best_sentences = set()
+            evidence_sentences = []
+            for sentence in sentences:
+                sentence_vector = sentences_vectors[sentence]
+                score = 1 - spatial.distance.cosine(query_vector, sentence_vector)
+                if score > 0.91:
+                    best_sentences.add(sentence)
+                    evidence_sentences.append(sentence)
+            if len(evidence_sentences) > 0:
+                search_results[i]['evidence'] = ' '.join(evidence_sentences)
+                search_results[i]['snippet'] = \
+                    ' '.join([s if s not in best_sentences else '<b>{}</b>'.format(s) for s in sentences])
+
+        search_results = [s for s in search_results if 'evidence' in s]
+
+        # Split into a different function below
+        # entailment_classification(self, search_results)
+        if not query_is_question:
+            logging.info('entailment classification...')
+            es_pairs = []
+            for result in search_results:
+                evidence = result['evidence']
+                es_pairs.append((evidence, query))
+            try:
+                labels, probs = self.nli_model(es_pairs)
+                logging.info(str(labels))
+                for i in range(len(labels)):
+                    confidence = numpy.exp(numpy.max(probs[i]))
+                    if confidence > 0.4:
+                        search_results[i]['nli_class'] = labels[i]
+                    else:
+                        search_results[i]['nli_class'] = 'neutral'
+                    search_results[i]['nli_confidence'] = str(confidence)
+            except Exception as e:
+                logging.warning('Error doing entailment classification')
+                logging.warning(str(e))
+            try:
+                search_results = [s for s in search_results if s['nli_class'] != 'neutral']
+            except:
+                pass
+
+        filtered_results = []
+        added_urls = set()
+        supporting = 0
+        refuting = 0
+        for r in search_results:
+            if r['url'] not in added_urls:
+                filtered_results.append(r)
+                added_urls.add(r['url'])
+                if 'nli_class' in r:
+                    if r['nli_class'] == 'entailment':
+                        supporting += 1
+                    elif r['nli_class'] == 'contradiction':
+                        refuting += 1
+        search_results = filtered_results
+
+        if supporting > 4 or refuting > 4:
+            if supporting / (refuting + 0.001) > 1.7:
+                veracity = 'Probably True'
+            elif refuting / (supporting + 0.001) > 1.7:
+                veracity = 'Probably False'
+            else:
+                veracity = '? Ambiguous'
+        else:
+            veracity = 'Not enough evidence'
+
+        search_results = search_results[:limit]
+
+        logging.info('done searching')
+
+        if query_is_question:
+            return {'type': 'question',
+                    'results': search_results}
+        else:
+            return {'type': 'statement',
+                    'supporting': supporting,
+                    'refuting': refuting,
+                    'veracity_rating': veracity,
+                    'results': search_results}
+
+    def build_endpoints(self):
+        @self.app.route('/search', methods=['POST', 'GET'])
+        def search_endpoint():
+            query = request.args.get('query').lower()
+            limit = request.args.get('limit') or 100
+            limit = min(int(limit), 100)
+            results = json.dumps(self.search_web_evidence(query, limit=limit), indent=4)
+            return results
+
+    def serve(self, port=12345):
+        self.build_endpoints()
+        self.app.run(host='0.0.0.0', port=port)
+
+
+
+fchecker = Quin()
+
+def predict(query, limit):
+    return fchecker.search_web_evidence(query, limit=limit)
+
+demo = gr.Interface(
+    fn=predict,
+    inputs=[
+        gr.Textbox(label="Enter Query"),  # string input
+        gr.Number(label="Enter Limit")  # number input
+    ],
+    outputs=gr.JSON(label="Output")
+)
+
+
+if __name__ == '__main__':
+    # demo.launch()
+    fchecker.search_web_evidence('Barrack Obama lived in Indonesia', limit=5)
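For reference, a minimal client sketch (not part of this commit) of how the /search endpoint added in quin_web.py might be exercised once the Flask app has been started with Quin().serve(port=12345); the localhost URL, timeout value, and example query below are assumptions for illustration only.

import json
import requests

# Hypothetical client call: assumes the server from this commit is running
# locally on the default port 12345 used by Quin.serve().
response = requests.get(
    "http://localhost:12345/search",
    params={"query": "Barack Obama lived in Indonesia", "limit": 5},
    timeout=300,  # the endpoint downloads, ranks, and classifies web pages, so it can be slow
)
response.raise_for_status()
data = json.loads(response.text)  # the endpoint returns a JSON-formatted string
print(data["type"], data.get("veracity_rating", ""))
for result in data["results"]:
    print(result["url"], result.get("nli_class", ""), result["evidence"])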
requirements.txt
CHANGED
@@ -8,7 +8,6 @@ lxml==4.6.3
 requests==2.22.0
 Flask_Cors==3.0.8
 gpustat==0.6.0
-h5py==2.10.0
 packaging==20.9
 pandas==1.3.4
 spacy==3.2.0