"""Quin web evidence search service: retrieves web pages for a query, ranks and
highlights relevant passages, and (for statement queries) runs entailment
classification to estimate veracity. Exposes a Flask /search endpoint and an
optional Gradio demo."""

import argparse
import gradio as gr
import json
import logging
import nltk
import numpy
import os
import re
import requests
import torch
import urllib
import urllib.parse as urlparse

from bs4 import BeautifulSoup
from flask import request, Flask, jsonify
from flask_cors import CORS
from huggingface_hub import hf_hub_download, snapshot_download, login
from models.nli import NLI
from models.qa_ranker import PassageRanker
from models.text_encoder import SentenceTransformer
from multiprocessing.pool import ThreadPool
from scipy import spatial
from scipy.special import softmax
from time import sleep
from urllib.parse import parse_qs
from web_search import get_html, WebParser, duckduckgo_search

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

API_URLS = ['https://letscheck.nus.edu.sg/quin/']

logging.getLogger().setLevel(logging.INFO)


def is_question(query):
    # Heuristic: queries that start with a question word or end with '?' are questions;
    # otherwise a query containing a finite verb is treated as a statement/claim.
    if re.match(r'^(who|when|what|why|which|whose|is|are|was|were|do|does|did|how)', query) or query.endswith('?'):
        return True
    pos_tags = nltk.pos_tag(nltk.word_tokenize(query))
    for tag in pos_tags:
        if tag[1].startswith('VB'):
            return False
    return True


def clear_cache():
    for url in API_URLS:
        print(f"Clearing cache from {url}")
        r = requests.post(f'{url}/clear_cache?secret=123')
        print(r.status_code)
        print(r)


class Quin:
    def __init__(self):
        nltk.download('punkt')
        nltk.download('averaged_perceptron_tagger')  # required by nltk.pos_tag in is_question()
        self.sent_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
        self.sent_tokenizer._params.abbrev_types.update(['e.g', 'i.e', 'subsp'])

        self.app = Flask(__name__)
        CORS(self.app)

        snapshot_download(repo_id="anabmaulana/qrbert-multitask",
                          local_dir="cache/qrbert-multitask")
        hf_hub_download(repo_id="anabmaulana/quin-passage-ranker",
                        filename="passage_ranker.state_dict",
                        local_dir="cache/quin-passage-ranker")
        hf_hub_download(repo_id="anabmaulana/quin-nli",
                        filename="nli.state_dict",
                        local_dir="cache/quin-nli")

        # torch.cuda.is_available = lambda: False
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.text_embedding_model = SentenceTransformer('cache/qrbert-multitask',
                                                        device=device,
                                                        parallel=True)

        self.passage_ranking_model = PassageRanker(
            model_path='cache/quin-passage-ranker/passage_ranker.state_dict',
            gpu=torch.cuda.is_available(),
            batch_size=16)
        self.passage_ranking_model.eval()

        self.nli_model = NLI('cache/quin-nli/nli.state_dict',
                             batch_size=64,
                             device=device,
                             parallel=True)
        self.nli_model.eval()

        logging.info('Initialized!')

    def extract_snippets(self, text, sentences_per_snippet=4):
        # Split a text block into snippets of up to `sentences_per_snippet` sentences,
        # skipping snippets that are shorter than 5 words.
        sentences = self.sent_tokenizer.tokenize(text)
        snippets = []
        i = 0
        last_index = 0
        while i < len(sentences):
            snippet = ' '.join(sentences[i:i + sentences_per_snippet])
            if len(snippet.split(' ')) > 4:
                snippets.append(snippet)
                last_index = i + sentences_per_snippet
            i += sentences_per_snippet
        if last_index < len(sentences):
            snippet = ' '.join(sentences[last_index:])
            if len(snippet.split(' ')) > 4:
                snippets.append(snippet)
        return snippets

    def search_web_evidence(self, query, limit=10):
        logging.info('searching the web...')
        urls = duckduckgo_search(query, pages=1)[:limit]

        logging.info('downloading {} web pages...'.format(len(urls)))
        search_results = []

        # Move this out of the function, since it is used a few times
        query_is_question = is_question(query)

        # Local methods for multithreading
        def download(url):
            nonlocal search_results
            data = get_html(url)
            soup = BeautifulSoup(data, features='lxml')
            title = soup.title.string if soup.title else url  # some pages have no <title>; fall back to the URL
            w = WebParser()
            w.feed(data)
            new_snippets = sum([self.extract_snippets(b) for b in w.blocks if b.count(' ') > 20], [])
            new_snippets = [{'snippet': snippet, 'url': url, 'title': title} for snippet in new_snippets]
            search_results += new_snippets

        def timeout_download(arg):
            pool = ThreadPool(1)
            try:
                pool.apply_async(download, [arg]).get(timeout=2)
            except Exception:
                pass
            pool.close()
            pool.join()

        # Fetch pages in parallel, with a per-page timeout
        p = ThreadPool(32)
        p.map(timeout_download, urls)
        p.close()
        p.join()

        if query_is_question:
            logging.info('re-ranking...')
            snippets = [s['snippet'] for s in search_results]
            qa_pairs = [(query, snippet) for snippet in snippets]
            _, probs = self.passage_ranking_model(qa_pairs)
            probs = [softmax(p)[1] for p in probs]
            # keep only snippets the passage ranker scores above the threshold
            filtered_results = []
            for i in range(len(search_results)):
                if probs[i] > 0.35:
                    search_results[i]['score'] = str(probs[i])
                    filtered_results.append(search_results[i])
            search_results = filtered_results
            search_results = sorted(search_results, key=lambda x: float(x['score']), reverse=True)

        # Split into a different function below
        # highlight_results(self, search_results)
        # highlight most relevant sentences
        logging.info('highlighting...')
        results_sentences = []
        sentences_texts = []
        sentences_vectors = {}
        for i, r in enumerate(search_results):
            sentences = self.sent_tokenizer.tokenize(r['snippet'])
            sentences = [s for s in sentences if len(s.split(' ')) > 4]
            sentences_texts.extend(sentences)
            results_sentences.append(sentences)
        vectors = self.text_embedding_model.encode(sentences=sentences_texts, batch_size=128)
        for i, v in enumerate(vectors):
            sentences_vectors[sentences_texts[i]] = v
        query_vector = self.text_embedding_model.encode(sentences=[query], batch_size=1)[0]
        for i, sentences in enumerate(results_sentences):
            best_sentences = set()
            evidence_sentences = []
            for sentence in sentences:
                sentence_vector = sentences_vectors[sentence]
                score = 1 - spatial.distance.cosine(query_vector, sentence_vector)
                if score > 0.91:
                    best_sentences.add(sentence)
                    evidence_sentences.append(sentence)
            if len(evidence_sentences) > 0:
                search_results[i]['evidence'] = ' '.join(evidence_sentences)
                search_results[i]['snippet'] = \
                    ' '.join([s if s not in best_sentences else '{}'.format(s) for s in sentences])
        search_results = [s for s in search_results if 'evidence' in s]

        # Split into a different function below
        # entailment_classification(self, search_results)
        if not query_is_question:
            logging.info('entailment classification...')
            es_pairs = []
            for result in search_results:
                evidence = result['evidence']
                es_pairs.append((evidence, query))
            try:
                labels, probs = self.nli_model(es_pairs)
                logging.info(str(labels))
                for i in range(len(labels)):
                    confidence = numpy.exp(numpy.max(probs[i]))
                    if confidence > 0.4:
                        search_results[i]['nli_class'] = labels[i]
                    else:
                        search_results[i]['nli_class'] = 'neutral'
                    search_results[i]['nli_confidence'] = str(confidence)
            except Exception as e:
                logging.warning('Error doing entailment classification')
                logging.warning(str(e))
            try:
                search_results = [s for s in search_results if s['nli_class'] != 'neutral']
            except Exception:
                pass

        # Deduplicate by URL and count supporting/refuting evidence
        filtered_results = []
        added_urls = set()
        supporting = 0
        refuting = 0
        for r in search_results:
            if r['url'] not in added_urls:
                filtered_results.append(r)
                added_urls.add(r['url'])
                if 'nli_class' in r:
                    if r['nli_class'] == 'entailment':
                        supporting += 1
                    elif r['nli_class'] == 'contradiction':
                        refuting += 1
        search_results = filtered_results

        if supporting > 4 or refuting > 4:
            if supporting / (refuting + 0.001) > 1.7:
                veracity = 'Probably True'
            elif refuting / (supporting + 0.001) > 1.7:
                veracity = 'Probably False'
            else:
                veracity = '? Ambiguous'
        else:
            veracity = 'Not enough evidence'

        search_results = search_results[:limit]

        logging.info('done searching')

        if query_is_question:
            return {'type': 'question',
                    'results': search_results}
        else:
            return {'type': 'statement',
                    'supporting': supporting,
                    'refuting': refuting,
                    'veracity_rating': veracity,
                    'results': search_results}

    def build_endpoints(self):
        @self.app.route('/search', methods=['POST', 'GET'])
        def search_endpoint():
            query = request.args.get('query').lower()
            limit = request.args.get('limit') or 100
            limit = min(int(limit), 100)
            results = json.dumps(self.search_web_evidence(query, limit=limit), indent=4)
            return results

    def serve(self, port=12345):
        self.build_endpoints()
        self.app.run(host='0.0.0.0', port=port)


fchecker = Quin()


def predict(query, limit):
    # gr.Number yields a float; the pipeline expects an integer limit
    return fchecker.search_web_evidence(query, limit=int(limit))


demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Enter Query"),  # string input
        gr.Number(label="Enter Limit")    # number input
    ],
    outputs=gr.JSON(label="Output")
)

if __name__ == '__main__':
    # demo.launch()
    fchecker.search_web_evidence('Barack Obama lived in Indonesia', limit=5)
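

# --- Hedged usage sketch (not called anywhere; illustration only) ---
# Shows how a client might query the /search endpoint exposed by build_endpoints()
# once `fchecker.serve(port=12345)` is running in a separate process. The helper
# name `example_search_request` and the localhost base_url are assumptions for a
# local deployment; the response keys mirror the dicts returned by
# search_web_evidence() above.
def example_search_request(query='Barack Obama lived in Indonesia', limit=5,
                           base_url='http://localhost:12345'):  # base_url is an assumption
    resp = requests.get(f'{base_url}/search', params={'query': query, 'limit': limit})
    data = resp.json()
    if data['type'] == 'statement':
        # statement queries carry an aggregate veracity rating plus evidence counts
        print(data['veracity_rating'], data['supporting'], data['refuting'])
    else:
        # question queries return ranked evidence snippets only
        print(len(data['results']))
    return data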